How to Make Your App Smarter | Youth Training Camp Notes

This is day 16 of my participation in the note-writing challenge of the 4th Youth Training Camp.

Introduction

As phones have evolved, their screens have grown larger and larger. A single hand can no longer cover the whole display, so tapping something on the far side of the screen one-handed is awkward. In those moments we find ourselves wishing the button were closer to us.

So is there a way to make these buttons automatically move toward the hand that is operating the phone?

There is, as long as we can tell whether the phone is currently being operated with the left hand or the right hand. If it is the left hand, shift the button to the left; if it is the right hand, shift it to the right.

With the rough idea in place, let's get to work!

Approach

The candidate schemes come from the camp's study material: 【Android 客户端专场 学习资料二】第四届字节跳动青训营 - 掘金 (juejin.cn). The scheme implemented below (scheme two) classifies the operating hand from touch-event trajectories with an on-device deep-learning model.

Implementation

Model training itself is not covered here; we use the trained model straight from ahcyd008/OperatingHandRecognition: on-device left/right-hand recognition Android demo + model training (github.com).

Importing the dependencies

Since scheme two is based on deep learning, we need to pull in a deep-learning framework for Android. These frameworks almost all come from the big tech companies; here we use Google's TensorFlow Lite, the TensorFlow runtime built for Android. The workflow is to compress and convert a regular TensorFlow model so that it can run inside an Android app.

The other two dependencies are Google's Tasks library and Google's Guava library. The former runs the inference work as background tasks and lets us observe their results, while Guava provides a richer set of Java utility classes.

//app/build.gradle
dependencies {
    // Task API
    implementation "com.google.android.gms:play-services-tasks:17.2.1"

    // TensorFlow Lite dependency
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'

    implementation("com.google.guava:guava:31.1-android")
}
//settings.gradle
pluginManagement {
    repositories {
        gradlePluginPortal()
        google()
        mavenCentral()
        maven { url "https://jitpack.io" }
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        google()
        mavenCentral()
        jcenter()
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}

That is the dependency setup. Note the ossrh-snapshot repository added at the end of settings.gradle: it is where the nightly TensorFlow Lite build used above is published.

Finally, remember to bundle the converted model file into the project; the code below loads mymodel.tflite from the app's assets directory.
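
One caveat: loadModelFile() below memory-maps the model through AssetManager.openFd(), which fails if the asset gets compressed during packaging, so it is worth telling AAPT to leave .tflite files uncompressed. A minimal sketch for app/build.gradle:

//app/build.gradle
android {
    aaptOptions {
        // keep .tflite assets uncompressed so they can be memory-mapped
        noCompress "tflite"
    }
}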

Wiring up the model

The OperatingHandClassifier below wraps the TFLite Interpreter: it loads the model on a background thread, resamples a recorded touch trajectory into the model's input tensor, and maps the single sigmoid output to a left/right label. MotionEventTracker is the companion class that records those trajectories from raw MotionEvents.

class OperatingHandClassifier(private val context: Context) {

    private var interpreter: Interpreter? = null
    private var modelInputSize = 0
    var isInitialized = false
        private set


    /** Executor to run inference task in the background */
    private val executorService: ExecutorService = Executors.newSingleThreadScheduledExecutor()
    private var hasInit = false
    fun checkAndInit() {
        if (hasInit) {
            return
        }
        hasInit = true
        val task = TaskCompletionSource<Void?>()
        executorService.execute {
            try {
                initializeInterpreter()
                task.setResult(null)
            } catch (e: IOException) {
                task.setException(e)
            }
        }
        task.task.addOnFailureListener { e -> Log.e(TAG, "Error setting up the operating-hand classifier.", e) }
    }

    @Throws(IOException::class)
    private fun initializeInterpreter() {
        // Load the TF Lite model
        val assetManager = context.assets
        val model = loadModelFile(assetManager)

        // Initialize TF Lite Interpreter with NNAPI enabled
        val options = Interpreter.Options()
        // NNAPI was found to have issues with MaxPooling1D; if on-device results
        // disagree with the Python predictions, try disabling NNAPI and re-checking
        options.setUseNNAPI(true)
        val interpreter = Interpreter(model, options)

        // Read input shape from model file
        val inputShape = interpreter.getInputTensor(0).shape()
        val simpleCount = inputShape[1]
        val tensorSize = inputShape[2]
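        // 4 bytes per float × sampleCount (9) samples × tensorSize (6) features, with PIXEL_SIZE == 1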
        modelInputSize = FLOAT_TYPE_SIZE * simpleCount * tensorSize * PIXEL_SIZE
        val outputShape = interpreter.getOutputTensor(0).shape()

        // Finish interpreter initialization
        this.interpreter = interpreter
        isInitialized = true
        Log.d(TAG, "Initialized TFLite interpreter. inputShape:${Arrays.toString(inputShape)}, outputShape:${Arrays.toString(outputShape)}")
    }

    @Throws(IOException::class)
    private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
        val fileDescriptor = assetManager.openFd(MODEL_FILE) // use the fully-connected model
        // val fileDescriptor = assetManager.openFd(MODEL_CNN_FILE) // use the CNN model instead
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

    private fun classify(pointList: JSONArray): ClassifierLabelResult {
        if (!isInitialized) {
            throw IllegalStateException("TF Lite Interpreter is not initialized yet.")
        }
        try {
            // Preprocessing: resample the points into the model's input ByteBuffer
            var startTime: Long = System.nanoTime()
            val byteBuffer = convertFloatArrayToByteBuffer(pointList)
            var elapsedTime = (System.nanoTime() - startTime) / 1000000
            Log.d(TAG, "Preprocessing time = " + elapsedTime + "ms")

            startTime = System.nanoTime()
            val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
            interpreter?.run(byteBuffer, result)
            elapsedTime = (System.nanoTime() - startTime) / 1000000
            Log.d(TAG, "Inference time = " + elapsedTime + "ms result=" + result[0].contentToString())

            // single sigmoid output: above 0.5 means right hand
            val output = result[0][0]
            return if (output > 0.5f) {
                ClassifierLabelResult(output, "right", labelRight)
            } else {
                ClassifierLabelResult(1.0f-output, "left", labelLeft)
            }
        } catch (e: Throwable) {
            Log.e(TAG, "Inference error", e)
        }
        return ClassifierLabelResult(-1f, "unknown", labelUnknown)
    }

    fun classifyAsync(pointList: JSONArray): Task<ClassifierLabelResult> {
        val task = TaskCompletionSource<ClassifierLabelResult>()
        executorService.execute {
            val result = classify(pointList)
            task.setResult(result)
        }
        return task.task
    }

    fun close() {
        executorService.execute {
            interpreter?.close()
            Log.d(TAG, "Closed TFLite interpreter.")
        }
    }

    private fun convertFloatArrayToByteBuffer(pointList: JSONArray): ByteBuffer {
        Log.d(TAG, "convertFloatArrayToByteBuffer pointList=$pointList")
        val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
        byteBuffer.order(ByteOrder.nativeOrder())
        val step = pointList.length().toFloat() / sampleCount
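        // uniformly resample the recorded trajectory down to sampleCount points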
        for (i in 0 until sampleCount) {
            val e = pointList[(i * step).toInt()] as JSONArray
            for (j in 0 until tensorSize) {
                val value = (e[j] as Number).toFloat() // x y w h density dtime
                byteBuffer.putFloat(value)
            }
        }
        return byteBuffer
    }

    companion object {
        private const val TAG = "ClientAI#Classifier"
        private const val MODEL_FILE = "mymodel.tflite"
        private const val FLOAT_TYPE_SIZE = 4
        private const val PIXEL_SIZE = 1
        private const val OUTPUT_CLASSES_COUNT = 1

        const val sampleCount = 9
        const val tensorSize = 6

        const val labelLeft = 0
        const val labelRight = 1
        const val labelUnknown = -1
    }
}

class ClassifierLabelResult(var score: Float, var label: String, val labelInt: Int) {
    override fun toString(): String {
        val format = DecimalFormat("#.##")
        return "$label score:${format.format(score)}"
    }
}
class MotionEventTracker(var context: Context) {
    companion object {
        const val TAG = "ClientAI#tracker"
    }

    interface ITrackDataReadyListener {
        fun onTrackDataReady(dataList: JSONArray)
    }

    private var width = 0
    private var height = 0
    private var density = 1f
    private var listener: ITrackDataReadyListener? = null

    fun checkAndInit(listener: ITrackDataReadyListener) {
        this.listener = listener
        val metric = context.resources.displayMetrics
        width = min(metric.widthPixels, metric.heightPixels)
        height = max(metric.widthPixels, metric.heightPixels)
        density = metric.density
    }

    private var currentEvents: JSONArray? = null
    private var currentDownTime = 0L

    fun recordMotionEvent(ev: MotionEvent) {
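        // ignore multi-touch gestures and reset any trajectory in progress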
        if (ev.pointerCount > 1) {
            currentEvents = null
            return
        }
        if (ev.action == MotionEvent.ACTION_DOWN) {
            currentEvents = JSONArray()
            currentDownTime = ev.eventTime
        }
        if (currentEvents != null) {
            if (ev.historySize > 0) {
                for (i in 0 until ev.historySize) {
                    currentEvents?.put(buildPoint(ev.getHistoricalX(i), ev.getHistoricalY(i), ev.getHistoricalEventTime(i)))
                }
            }
            currentEvents?.put(buildPoint(ev.x, ev.y, ev.eventTime))
        }
        if (ev.action == MotionEvent.ACTION_UP) {
            currentEvents?.let {
                if (it.length() >= 6) {
                    listener?.onTrackDataReady(it) // trigger a prediction
                    Log.i(TAG, "cache events, eventCount=${it.length()}, data=$it")
                } else {
                    // filter out plain taps and accidental touches
                    Log.i(TAG, "skipped short events, eventCount=${it.length()}, data=$it")
                }
            }
            currentEvents = null
        }
    }

    private fun buildPoint(x: Float, y: Float, timestamp: Long): JSONArray {
        val point = JSONArray()
        point.put(x)
        point.put(y)
        point.put(width)
        point.put(height)
        point.put(density)
        point.put(timestamp - currentDownTime)
        return point
    }


}
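
Reading the two classes together: MotionEventTracker turns each single-finger gesture into a JSONArray of [x, y, width, height, density, dt] points, and the classifier resamples that trajectory to 9 × 6 floats before inference. Hooked up by hand it looks roughly like this (a sketch of the wiring; the BaseActivity later in this post does the same thing):

// Sketch: `context` is the hosting Activity; the class names match the code above
val classifier = OperatingHandClassifier(context).also { it.checkAndInit() }
val tracker = MotionEventTracker(context)
tracker.checkAndInit(object : MotionEventTracker.ITrackDataReadyListener {
    override fun onTrackDataReady(dataList: JSONArray) {
        classifier.classifyAsync(dataList)
            .addOnSuccessListener { result -> Log.d("ClientAI", result.toString()) }
    }
})
// then forward every touch event, e.g. from Activity.dispatchTouchEvent(ev):
// tracker.recordMotionEvent(ev)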

Utility class

This utility class assists the left/right-hand prediction: the model was trained on a fairly small sample, so individual predictions are not very accurate. We therefore keep the last three predictions in a queue and take whichever label appears most often as the final result.

public class QueueUtil {
	private static final ArrayBlockingQueue<Integer> handList = new ArrayBlockingQueue<>(3);
	private static int handLeft = 0;
	private static int handRight = 0;

	public static int getRecentHand(int labelHand) {
		int poll;
		if (!handList.offer(labelHand)) {
			poll = handList.remove();
			handList.offer(labelHand);
			if (poll == labelHand) {
				return compareRecentHand(handLeft, handRight);
			} else {
				switch (poll) {
				case OperatingHandClassifier.labelRight:
					handRight--;
					break;
				case OperatingHandClassifier.labelLeft:
					handLeft--;
					break;
				default:
					break;
				}
			}
		}

		switch (labelHand) {
		case OperatingHandClassifier.labelRight:
			handRight++;
			break;
		case OperatingHandClassifier.labelLeft:
			handLeft++;
			break;
		default:
			break;
		}
		return compareRecentHand(handLeft, handRight);
	}

	private static int compareRecentHand(int handLeft, int handRight) {
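		// on a tie, default to the right hand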
		if (handLeft > handRight) {
			return OperatingHandClassifier.labelLeft;
		} else {
			return OperatingHandClassifier.labelRight;
		}
	}
}
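
A quick illustration of the smoothing behaviour (a hypothetical driver program, not part of the app):

fun main() {
    val raw = listOf(
        OperatingHandClassifier.labelLeft,
        OperatingHandClassifier.labelRight,
        OperatingHandClassifier.labelLeft,
    )
    for (p in raw) {
        // smoothed output: 0, 1, 0 (the tie at step two goes to the right hand,
        // but the final majority over the three-element window is left)
        println("raw=$p smoothed=${QueueUtil.getRecentHand(p)}")
    }
}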

Wrapping it all in a BaseActivity

To make this convenient to use, the prediction-related code is wrapped in a BaseActivity. Any Activity that extends it only needs to add the @InitAIHand annotation for the feature to kick in.

The complete code is available in our project dyjcow/qxy_potato at feature_AIDialog_DYJ (github.com).
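
The @InitAIHand marker itself does not appear in this excerpt. Given the isAnnotationPresent() checks below, a minimal definition would presumably look like this (a sketch, assuming a runtime-retained class annotation):

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS)
annotation class InitAIHand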

public abstract class BaseActivity<P extends BasePresenter<? extends BaseView>, VB extends ViewBinding>
        extends AppCompatActivity implements BaseView, MotionEventTracker.ITrackDataReadyListener {

    /**
     * Reference to the presenter layer
     */
    protected P presenter;

    private VB binding;

    private OperatingHandClassifier classifier;

    private MotionEventTracker tracker;

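    /** Most recent smoothed prediction; defaults to 1, i.e. labelRight. */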
    public int hand = 1;



    /**
     * {@inheritDoc}
     * <p>
     * Perform initialization of all fragments.
     *
     * @param savedInstanceState
     */
    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        if (this.getClass().isAnnotationPresent(BindEventBus.class)) {
            EventBus.getDefault().register(this);
        } else if (this.getClass().isAnnotationPresent(InitAIHand.class)) {
            classifier = new OperatingHandClassifier(this);
            classifier.checkAndInit();
            tracker = new MotionEventTracker(this);
            tracker.checkAndInit(this);
        }
        DisplayUtil.setCustomDensity(this);
        UltimateBarX.statusBarOnly(this)
                .light(true)
                .transparent()
                .apply();
        // force portrait orientation
        setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_PORTRAIT);
        binding = ViewBindingUtil.inflateWithGeneric(this, getLayoutInflater());
        setContentView(binding.getRoot());
        presenter = createPresenter();

        initView();
        initData();
    }

    /**
     * Creates the presenter and binds it to this Activity.
     *
     * @return the newly created presenter instance
     */
    protected abstract P createPresenter();


    /**
     * Set up the views
     */
    protected abstract void initView();


    /**
     * Load the data
     */
    protected abstract void initData();

    /**
     * Unbinds the presenter from the Activity
     */
    @Override
    protected void onDestroy() {
        super.onDestroy();
        if (this.getClass().isAnnotationPresent(BindEventBus.class)) {
            EventBus.getDefault().unregister(this);
        } else if (this.getClass().isAnnotationPresent(InitAIHand.class)) {
            classifier.close();
        }
        if (presenter != null) {
            presenter.detachView();
        }
    }

    @Override
    public void showLoading() {
        MyUtil.showLoading(this);
    }

    @Override
    public void SuccessHideLoading() {
        MyUtil.dismissSuccessLoading();
    }

    @Override
    public void FailedHideLoading() {
        MyUtil.dismissFailedLoading();
    }

    /**
     * Error callback
     *
     * @param bean the error information
     */
    @Override
    public void onErrorCode(BaseBean bean) {
        ToastUtil.showToast(bean.msg);
    }

    public VB getBinding() {
        return binding;
    }

    /**
     * Called to process touch screen events.  You can override this to
     * intercept all touch screen events before they are dispatched to the
     * window.  Be sure to call this implementation for touch screen events
     * that should be handled normally.
     *
     * @param ev The touch screen event.
     * @return boolean Return true if this event was consumed.
     */
    @Override public boolean dispatchTouchEvent(MotionEvent ev) {
        if (tracker != null && ev != null){
            tracker.recordMotionEvent(ev);
        }
        return super.dispatchTouchEvent(ev);
    }

    @Override public void onTrackDataReady(@NonNull JSONArray dataList) {
        if (classifier != null){
            classifier.classifyAsync(dataList).addOnSuccessListener(result -> {
                hand = QueueUtil.getRecentHand(result.getLabelInt());
                LogUtil.d(MotionEventTracker.TAG,result.getLabel());
            }).addOnFailureListener(e -> LogUtil.e(MotionEventTracker.TAG,e.toString()));
        }
    }

}
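
Opting in from a concrete screen then takes a single annotation. Here is a sketch; the presenter, binding, and view names are illustrative, not from the real project:

@InitAIHand
class VersionActivity : BaseActivity<VersionPresenter, ActivityVersionBinding>() {

    private val versions = listOf("v1.0", "v1.1", "v2.0")

    override fun createPresenter() = VersionPresenter()

    override fun initView() {
        // `hand` always holds the latest smoothed prediction
        getBinding().btnPickVersion.setOnClickListener {
            MyUtil.showOneOptionPicker(versions, hand)
        }
    }

    override fun initData() {}
}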

Writing the dialog code

The dialog's scroll-wheel effect is actually implemented with an open-source picker library:

Bigkoo/Android-PickerView: This is a picker view for android , support linkage effect, timepicker and optionspicker.(时间选择器、省市区三级联动) (github.com)

public class MyUtil{
    ...
	public static void showOneOptionPicker(List<?> list, int handLabel) {

        OptionsPickerBuilder builder = new OptionsPickerBuilder(ActivityUtil.getCurrentActivity(),
                (options1, options2, options3, v) -> {
                    // options1/options2/options3 are the selected positions of the three levels
                    BaseEvent<?> event = new BaseEvent<>(EventCode.SELECT_VERSION, list.get(options1));
                    EventBusUtil.sendEvent(event);

                });
        pvOptions = builder
                .setTextColorCenter(Color.BLACK) // text color of the selected item
                .setContentTextSize(19)
                .setDividerColor(Color.GRAY)
                .setDividerType(WheelView.DividerType.WRAP)
                .isAlphaGradient(true)
                .setLayoutRes(R.layout.layout_pickview_dialog, v -> {
                    // choose the confirm button (left or right) based on the predicted hand
                    TextView textView;
                    if (handLabel == OperatingHandClassifier.labelRight){
                        textView = v.findViewById(R.id.btnSubmitRight);
                    }else {
                        textView = v.findViewById(R.id.btnSubmitLeft);
                    }
                    // now make the chosen button visible
                    textView.setVisibility(View.VISIBLE);
                    textView.setOnClickListener(v1 -> {
                        pvOptions.returnData();
                        pvOptions.dismiss();
                    });
                })
                .build();
        pvOptions.setPicker(list); // single-level picker
        pvOptions.show();
    }
    ...
}

Now let's look at the corresponding layout.

This is a modified copy of the library's original dialog layout: it adds a confirm button on each side (both initially invisible), while the three WheelViews underneath are kept as-is, without further optimization.

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
        xmlns:app="http://schemas.android.com/apk/res-auto"
        xmlns:tools="http://schemas.android.com/tools"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:background="@drawable/shape_sheet_dialog_bg_white">

    <TextView
            android:id="@+id/btnSubmitLeft"
            android:text="@string/pick_submit"
            android:textStyle="bold"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:visibility="invisible"
            app:layout_constraintTop_toTopOf="parent"
            app:layout_constraintBottom_toTopOf="@+id/optionspicker"
            android:layout_marginStart="20dp"
            app:layout_constraintStart_toStartOf="parent"
            android:layout_marginBottom="10dp"
            android:layout_marginTop="15dp" />

    <TextView
            android:id="@+id/btnSubmitRight"
            android:text="@string/pick_submit"
            android:textStyle="bold"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:visibility="invisible"
            app:layout_constraintTop_toTopOf="parent"
            app:layout_constraintBottom_toTopOf="@+id/optionspicker"
            android:layout_marginEnd="20dp"
            app:layout_constraintEnd_toEndOf="parent"
            android:layout_marginBottom="10dp"
            android:layout_marginTop="15dp" />


    <LinearLayout
            android:id="@+id/optionspicker"
            android:layout_width="match_parent"
            android:layout_height="match_parent"
            android:background="@android:color/white"
            android:gravity="center"
            android:minHeight="180dp"
            android:orientation="horizontal"
            app:layout_constraintBottom_toBottomOf="parent"
            app:layout_constraintTop_toBottomOf="@id/btnSubmitRight">

        <com.contrarywind.view.WheelView
                android:id="@+id/options1"
                android:layout_width="0dp"
                android:layout_height="match_parent"
                android:layout_weight="1" />

        <com.contrarywind.view.WheelView
                android:id="@+id/options2"
                android:layout_width="0dp"
                android:layout_height="match_parent"
                android:layout_weight="1" />

        <com.contrarywind.view.WheelView
                android:id="@+id/options3"
                android:layout_width="0dp"
                android:layout_height="match_parent"
                android:layout_weight="1" />

    </LinearLayout>

</androidx.constraintlayout.widget.ConstraintLayout>

Result

When we tap with the left hand, the confirm button shows up on the left; when we tap with the right hand, it shows up on the right.

References

ahcyd008/OperatingHandRecognition: on-device left/right-hand recognition Android demo + model training (github.com)

【Android 客户端专场 学习资料二】第四届字节跳动青训营 - 掘金 (juejin.cn)

Recognizing the Operating Hand and the Hand-Changing Process for User Interface Adjustment on Smartphones - PMC (nih.gov)

