How to Make Your App Smarter | Youth Training Camp Notes
This is day 16 of my participation in the note-writing challenge of the 4th Youth Training Camp.
Introduction
As phones keep evolving, their screens keep getting bigger. A single hand can no longer cover the whole display, so reaching a control on the far side of the screen one-handed feels awkward. In those moments we catch ourselves wishing the button were just a bit closer.
Is there a way to make those buttons automatically move toward the hand that is operating the phone?
There is: we only need to figure out whether the phone is currently being operated with the left hand or the right hand. If it is the left hand, the button shifts to the left; if it is the right hand, it shifts to the right.
With the rough idea in place, let's get to work!
Approaches
Option 1
Option 2, the machine-learning approach (the one we use):
- Train a binary-classification CNN model to recognize whether the user is operating with the left or the right hand.
- Input: the user's swipe trajectory on the screen
- Output: left hand or right hand
Source: 【Android 客户端专场 学习资料二】第四届字节跳动青训营 - 掘金 (juejin.cn)
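Before diving into the Android code, it helps to make the model's input and output concrete: each trajectory point is encoded as six floats, a swipe is resampled to a fixed number of points, and the model emits a single sigmoid score. The Kotlin sketch below is illustrative only (the TrackPoint name is made up); the constants and the 0.5 threshold simply mirror the classifier code later in this post:
// Illustrative only: mirrors the shapes used by the classifier further below.
data class TrackPoint(
    val x: Float, val y: Float,   // touch position
    val w: Float, val h: Float,   // screen width / height in pixels
    val density: Float,           // display density
    val dtime: Float              // milliseconds since ACTION_DOWN
)

const val SAMPLE_COUNT = 9  // points resampled from one swipe
const val FEATURE_SIZE = 6  // x, y, w, h, density, dtime

// Model input:  [1, SAMPLE_COUNT, FEATURE_SIZE] floats
// Model output: [1, 1] sigmoid score; > 0.5f means right hand, otherwise left
fun decide(score: Float): String = if (score > 0.5f) "right" else "left"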
Implementation
Training the samples is not covered here; the model is taken directly from ahcyd008/OperatingHandRecognition: on-device left/right-hand recognition Android demo + model training (github.com).
Importing the dependencies
Since Option 2 is based on deep learning, we need to pull in a deep-learning framework for Android. These frameworks almost all come from a handful of big tech companies; the one we use here is Google's TensorFlow Lite, the flavor of TensorFlow tailored to Android: a regular model is compressed and converted, and the result can then run inside an Android app.
The other two libraries are Google's Tasks library and Google's Guava library. The former provides the Task API we use to run the inference work in the background and listen for its result, while Guava supplies a richer set of general-purpose Java utilities.
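The same Task pattern shows up again in the classifier later in this post. As a quick orientation, here is a minimal Kotlin sketch (not project code; the string payload is just a placeholder) of how play-services-tasks hands a result from a worker thread back to the caller:
import android.util.Log
import com.google.android.gms.tasks.TaskCompletionSource
import java.util.concurrent.Executors

// Minimal sketch: complete a TaskCompletionSource on a worker thread and
// observe the result (or failure) from the calling side.
fun taskDemo() {
    val source = TaskCompletionSource<String>()
    Executors.newSingleThreadExecutor().execute {
        source.setResult("inference finished") // placeholder payload
    }
    source.task
        .addOnSuccessListener { Log.d("TaskDemo", it) }
        .addOnFailureListener { e -> Log.e("TaskDemo", "failed", e) }
}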
//app/build.gradle
dependencies {
    // Task API
    implementation "com.google.android.gms:play-services-tasks:17.2.1"
    // TensorFlow Lite dependency
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
    implementation("com.google.guava:guava:31.1-android")
}
//settings.gradle
pluginManagement {
    repositories {
        gradlePluginPortal()
        google()
        mavenCentral()
        maven { url "https://jitpack.io" }
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
    repositories {
        google()
        mavenCentral()
        jcenter()
        maven {
            name 'ossrh-snapshot'
            url 'https://oss.sonatype.org/content/repositories/snapshots'
        }
    }
}
The Gradle snippets above pull in the libraries; remember to add TensorFlow's snapshot repository address at the end (the ossrh-snapshot entries), which hosts the nightly TensorFlow Lite build we depend on.
Finally, remember to bundle the converted model file into the project.
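Concretely, the converted .tflite file lives in the module's assets/ directory (mymodel.tflite in this demo). A quick way to verify it was packaged correctly is to open its file descriptor at startup: openFd() fails if the asset is missing or ended up compressed, which would also break the memory-mapping done in loadModelFile() below. Depending on your Android Gradle Plugin version you may additionally need to add "tflite" to aaptOptions.noCompress. The helper below is only an illustrative sketch, not part of the demo project:
import android.content.Context
import android.util.Log

// Illustrative helper: confirms the packaged model can be opened from assets/.
// openFd() throws an IOException if the file is missing or was compressed.
fun checkModelAsset(context: Context, fileName: String = "mymodel.tflite") {
    context.assets.openFd(fileName).use { fd ->
        Log.d("ClientAI#Check", "$fileName found, ${fd.declaredLength} bytes")
    }
}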
Hooking up the model
class OperatingHandClassifier(private val context: Context) {
private var interpreter: Interpreter? = null
private var modelInputSize = 0
var isInitialized = false
private set
/** Executor to run inference task in the background */
private val executorService: ExecutorService = Executors.newSingleThreadScheduledExecutor()
private var hasInit = false
fun checkAndInit() {
if (hasInit) {
return
}
hasInit = true
val task = TaskCompletionSource<Void?>()
executorService.execute {
try {
initializeInterpreter()
task.setResult(null)
} catch (e: IOException) {
task.setException(e)
}
}
task.task.addOnFailureListener { e -> Log.e(TAG, "Error setting up the operating-hand classifier.", e) }
}
@Throws(IOException::class)
private fun initializeInterpreter() {
// Load the TF Lite model
val assetManager = context.assets
val model = loadModelFile(assetManager)
// Initialize TF Lite Interpreter with NNAPI enabled
val options = Interpreter.Options()
// Testing showed NNAPI has trouble with MaxPooling1D; if on-device predictions differ from the Python results, try disabling NNAPI and check again
options.setUseNNAPI(true)
val interpreter = Interpreter(model, options)
// Read input shape from model file
val inputShape = interpreter.getInputTensor(0).shape()
val simpleCount = inputShape[1]
val tensorSize = inputShape[2]
modelInputSize = FLOAT_TYPE_SIZE * simpleCount * tensorSize * PIXEL_SIZE
val outputShape = interpreter.getOutputTensor(0).shape()
// Finish interpreter initialization
this.interpreter = interpreter
isInitialized = true
Log.d(TAG, "Initialized TFLite interpreter. inputShape:${Arrays.toString(inputShape)}, outputShape:${Arrays.toString(outputShape)}")
}
@Throws(IOException::class)
private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
val fileDescriptor = assetManager.openFd(MODEL_FILE) // use the fully-connected network model
// val fileDescriptor = assetManager.openFd(MODEL_CNN_FILE) // use the CNN model instead
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
private fun classify(pointList: JSONArray): ClassifierLabelResult {
if (!isInitialized) {
throw IllegalStateException("TF Lite Interpreter is not initialized yet.")
}
try {
// Preprocessing: resize the input
var startTime: Long = System.nanoTime()
val byteBuffer = convertFloatArrayToByteBuffer(pointList)
var elapsedTime = (System.nanoTime() - startTime) / 1000000
Log.d(TAG, "Preprocessing time = " + elapsedTime + "ms")
startTime = System.nanoTime()
val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
interpreter?.run(byteBuffer, result)
elapsedTime = (System.nanoTime() - startTime) / 1000000
Log.d(TAG, "Inference time = " + elapsedTime + "ms result=" + result[0].contentToString())
// single sigmoid output: score > 0.5 means right hand, otherwise left
val output = result[0][0]
return if (output > 0.5f) {
ClassifierLabelResult(output, "right", labelRight)
} else {
ClassifierLabelResult(1.0f-output, "left", labelLeft)
}
} catch (e: Throwable) {
Log.e(TAG, "Inference error", e)
}
return ClassifierLabelResult(-1f, "unknown", labelUnknown)
}
fun classifyAsync(pointList: JSONArray): Task<ClassifierLabelResult> {
val task = TaskCompletionSource<ClassifierLabelResult>()
executorService.execute {
val result = classify(pointList)
task.setResult(result)
}
return task.task
}
fun close() {
executorService.execute {
interpreter?.close()
Log.d(TAG, "Closed TFLite interpreter.")
}
}
private fun convertFloatArrayToByteBuffer(pointList: JSONArray): ByteBuffer {
Log.d(TAG, "convertFloatArrayToByteBuffer pointList=$pointList")
val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
byteBuffer.order(ByteOrder.nativeOrder())
val step = pointList.length().toFloat() / sampleCount
for (i in 0 until sampleCount) {
val e = pointList[(i * step).toInt()] as JSONArray
for (j in 0 until tensorSize) {
val value = (e[j] as Number).toFloat() // x y w h density dtime
byteBuffer.putFloat(value)
}
}
return byteBuffer
}
companion object {
private const val TAG = "ClientAI#Classifier"
private const val MODEL_FILE = "mymodel.tflite"
private const val FLOAT_TYPE_SIZE = 4
private const val PIXEL_SIZE = 1
private const val OUTPUT_CLASSES_COUNT = 1
const val sampleCount = 9
const val tensorSize = 6
const val labelLeft = 0;
const val labelRight = 1;
const val labelUnknown = -1;
}
}
class ClassifierLabelResult(var score: Float, var label: String, val labelInt: Int) {
override fun toString(): String {
val format = DecimalFormat("#.##")
return "$label score:${format.format(score)}"
}
}
class MotionEventTracker(var context: Context) {
companion object {
const val TAG = "ClientAI#tracker"
}
interface ITrackDataReadyListener {
fun onTrackDataReady(dataList: JSONArray)
}
private var width = 0
private var height = 0
private var density = 1f
private var listener: ITrackDataReadyListener? = null
fun checkAndInit(listener: ITrackDataReadyListener) {
this.listener = listener
val metric = context.resources.displayMetrics
width = min(metric.widthPixels, metric.heightPixels)
height = max(metric.widthPixels, metric.heightPixels)
density = metric.density
}
private var currentEvents: JSONArray? = null
private var currentDownTime = 0L
fun recordMotionEvent(ev: MotionEvent) {
if (ev.pointerCount > 1) {
currentEvents = null
return
}
if (ev.action == MotionEvent.ACTION_DOWN) {
currentEvents = JSONArray()
currentDownTime = ev.eventTime
}
if (currentEvents != null) {
if (ev.historySize > 0) {
for (i in 0 until ev.historySize) {
currentEvents?.put(buildPoint(ev.getHistoricalX(i), ev.getHistoricalY(i), ev.getHistoricalEventTime(i)))
}
}
currentEvents?.put(buildPoint(ev.x, ev.y, ev.eventTime))
}
if (ev.action == MotionEvent.ACTION_UP) {
currentEvents?.let {
if (it.length() >= 6) {
listener?.onTrackDataReady(it) // trigger a prediction
Log.i(TAG, "cache events, eventCount=${it.length()}, data=$it")
} else {
// filter out taps and accidental-touch trajectories
Log.i(TAG, "skipped short events, eventCount=${it.length()}, data=$it")
}
}
currentEvents = null
}
}
private fun buildPoint(x: Float, y: Float, timestamp: Long): JSONArray {
val point = JSONArray()
point.put(x)
point.put(y)
point.put(width)
point.put(height)
point.put(density)
point.put(timestamp - currentDownTime)
return point
}
}
Utility class
This utility class backs up the AI's left/right-hand decision. The model was trained on a fairly small amount of data, so individual predictions are not very accurate; we therefore keep the three most recent predictions in a queue and take the majority as the final answer (see the usage sketch right after the class).
public class QueueUtil {
    private static final ArrayBlockingQueue<Integer> handList = new ArrayBlockingQueue<>(3);
    private static int handLeft = 0;
    private static int handRight = 0;

    public static int getRecentHand(int labelHand) {
        int poll;
        if (!handList.offer(labelHand)) {
            poll = handList.remove();
            handList.offer(labelHand);
            if (poll == labelHand) {
                return compareRecentHand(handLeft, handRight);
            } else {
                switch (poll) {
                    case OperatingHandClassifier.labelRight:
                        handRight--;
                        break;
                    case OperatingHandClassifier.labelLeft:
                        handLeft--;
                        break;
                    default:
                        break;
                }
            }
        }
        switch (labelHand) {
            case OperatingHandClassifier.labelRight:
                handRight++;
                break;
            case OperatingHandClassifier.labelLeft:
                handLeft++;
                break;
            default:
                break;
        }
        return compareRecentHand(handLeft, handRight);
    }

    private static int compareRecentHand(int handLeft, int handRight) {
        if (handLeft > handRight) {
            return OperatingHandClassifier.labelLeft;
        } else {
            return OperatingHandClassifier.labelRight;
        }
    }
}
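A quick illustrative usage sketch (Kotlin, calling the Java helper above): after three calls the majority wins, so one stray misprediction does not flip the UI.
// Illustrative only: smoothing three consecutive predictions.
val first = QueueUtil.getRecentHand(OperatingHandClassifier.labelLeft)   // -> labelLeft
val second = QueueUtil.getRecentHand(OperatingHandClassifier.labelRight) // -> labelRight (ties go right)
val third = QueueUtil.getRecentHand(OperatingHandClassifier.labelLeft)   // -> labelLeft (2 of the last 3)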
Wrapping it into BaseActivity
To make the feature convenient to use, the code that runs the model is wrapped into BaseActivity: any Activity that extends it can enable the feature simply by carrying the corresponding annotation (a usage sketch follows the class below).
The complete code for this feature lives in our main project: dyjcow/qxy_potato at feature_AIDialog_DYJ (github.com)
public abstract class BaseActivity<P extends BasePresenter<? extends BaseView>, VB extends ViewBinding>
extends AppCompatActivity implements BaseView, MotionEventTracker.ITrackDataReadyListener {
/**
* Reference to the presenter layer
*/
protected P presenter;
private VB binding;
private OperatingHandClassifier classifier;
private MotionEventTracker tracker;
public int hand = 1;
/**
* {@inheritDoc}
* <p>
* Perform initialization of all fragments.
*
* @param savedInstanceState
*/
@Override
protected void onCreate(@Nullable Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (this.getClass().isAnnotationPresent(BindEventBus.class)) {
EventBus.getDefault().register(this);
}else if (this.getClass().isAnnotationPresent(InitAIHand.class)){
classifier = new OperatingHandClassifier(this);
classifier.checkAndInit();
tracker = new MotionEventTracker(this);
tracker.checkAndInit(this);
}
DisplayUtil.setCustomDensity(this);
UltimateBarX.statusBarOnly(this)
.light(true)
.transparent()
.apply();
// force portrait orientation
setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_PORTRAIT);
binding = ViewBindingUtil.inflateWithGeneric(this, getLayoutInflater());
setContentView(binding.getRoot());
presenter = createPresenter();
initView();
initData();
}
/**
* Initializes the presenter and binds it to this Activity
*
* @return the newly created presenter instance
*/
protected abstract P createPresenter();
/**
* Hook for view initialization
*/
protected abstract void initView();
/**
* Hook for data loading
*/
protected abstract void initData();
/**
* Unbinds the presenter from the Activity
*/
@Override
protected void onDestroy() {
super.onDestroy();
if (this.getClass().isAnnotationPresent(BindEventBus.class)) {
EventBus.getDefault().unregister(this);
}else if (this.getClass().isAnnotationPresent(InitAIHand.class)){
classifier.close();
}
if (presenter != null) {
presenter.detachView();
}
}
@Override
public void showLoading() {
MyUtil.showLoading(this);
}
@Override
public void SuccessHideLoading() {
MyUtil.dismissSuccessLoading();
}
@Override
public void FailedHideLoading() {
MyUtil.dismissFailedLoading();
}
/**
* Error callback
*
* @param bean the error information
*/
@Override
public void onErrorCode(BaseBean bean) {
ToastUtil.showToast(bean.msg);
}
public VB getBinding() {
return binding;
}
/**
* Called to process touch screen events. You can override this to
* intercept all touch screen events before they are dispatched to the
* window. Be sure to call this implementation for touch screen events
* that should be handled normally.
*
* @param ev The touch screen event.
* @return boolean Return true if this event was consumed.
*/
@Override public boolean dispatchTouchEvent(MotionEvent ev) {
if (tracker != null && ev != null){
tracker.recordMotionEvent(ev);
}
return super.dispatchTouchEvent(ev);
}
@Override public void onTrackDataReady(@NonNull JSONArray dataList) {
if (classifier != null){
classifier.classifyAsync(dataList).addOnSuccessListener(result -> {
hand = QueueUtil.getRecentHand(result.getLabelInt());
LogUtil.d(MotionEventTracker.TAG,result.getLabel());
}).addOnFailureListener(e -> LogUtil.e(MotionEventTracker.TAG,e.toString()));
}
}
}
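For completeness, here is roughly what a screen opting into the feature looks like. This is only a sketch: MainPresenter, ActivityMainBinding and btnVersion are placeholder names rather than code from the project; the point is that the @InitAIHand annotation switches the tracking on, and the hand field is read when showing the picker helper from the next section.
// Illustrative subclass sketch; MainPresenter, ActivityMainBinding and btnVersion
// are placeholders, not part of the article's project.
@InitAIHand
class MainActivity : BaseActivity<MainPresenter, ActivityMainBinding>() {

    override fun createPresenter(): MainPresenter = MainPresenter()

    override fun initView() {
        // `hand` is refreshed by BaseActivity after every recognized swipe.
        getBinding().btnVersion.setOnClickListener {
            MyUtil.showOneOptionPicker(listOf("v1.0", "v1.1", "v2.0"), hand)
        }
    }

    override fun initData() {}
}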
Writing the picker dialog
The dialog's scrolling wheels are actually implemented with an open-source picker library (the OptionsPickerBuilder / WheelView APIs used below).
public class MyUtil{
...
public static void showOneOptionPicker(List<?> list, int handLabel) {
OptionsPickerBuilder builder = new OptionsPickerBuilder(ActivityUtil.getCurrentActivity(),
(options1, options2, options3, v) -> {
//options1/2/3 are the selected positions of the three picker levels
BaseEvent<?> event = new BaseEvent<>(EventCode.SELECT_VERSION, list.get(options1));
EventBusUtil.sendEvent(event);
});
pvOptions = builder
.setDividerColor(Color.BLACK)
.setTextColorCenter(Color.BLACK) //text color of the selected item
.setContentTextSize(19)
.setDividerColor(Color.GRAY)
.setDividerType(WheelView.DividerType.WRAP)
.isAlphaGradient(true)
.setLayoutRes(R.layout.layout_pickview_dialog, v -> {
//pick the confirm button on the side matching the detected hand
TextView textView;
if (handLabel == OperatingHandClassifier.labelRight){
textView = v.findViewById(R.id.btnSubmitRight);
}else {
textView = v.findViewById(R.id.btnSubmitLeft);
}
//make the chosen button visible
textView.setVisibility(View.VISIBLE);
textView.setOnClickListener(v1 -> {
pvOptions.returnData();
pvOptions.dismiss();
});
})
.build();
pvOptions.setPicker(list); //single-level picker
pvOptions.show();
}
...
}
Now let's look at the corresponding layout.
It is a lightly reworked copy of the library's original layout; the bottom part still uses three WheelViews and has not been further optimized.
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:background="@drawable/shape_sheet_dialog_bg_white">
<TextView
android:id="@+id/btnSubmitLeft"
android:text="@string/pick_submit"
android:textStyle="bold"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:visibility="invisible"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintBottom_toTopOf="@+id/optionspicker"
android:layout_marginStart="20dp"
app:layout_constraintStart_toStartOf="parent"
android:layout_marginBottom="10dp"
android:layout_marginTop="15dp" />
<TextView
android:id="@+id/btnSubmitRight"
android:text="@string/pick_submit"
android:textStyle="bold"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:visibility="invisible"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintBottom_toTopOf="@+id/optionspicker"
android:layout_marginEnd="20dp"
app:layout_constraintEnd_toEndOf="parent"
android:layout_marginBottom="10dp"
android:layout_marginTop="15dp" />
<LinearLayout
android:id="@+id/optionspicker"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@android:color/white"
android:gravity="center"
android:minHeight="180dp"
android:orientation="horizontal"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintTop_toBottomOf="@id/btnSubmitRight">
<com.contrarywind.view.WheelView
android:id="@+id/options1"
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1" />
<com.contrarywind.view.WheelView
android:id="@+id/options2"
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1" />
<com.contrarywind.view.WheelView
android:id="@+id/options3"
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1" />
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
Result
When we tap with the left hand, the confirm button sits on the left; when we tap with the right hand, it sits on the right.
References
ahcyd008/OperatingHandRecognition: on-device left/right-hand recognition Android demo + model training (github.com)