赞
踩
人工智能火了,tensorflow 也火了,Google推出移动版的TensorFlow Lite,作为一个Android开发应该熟悉一下。今天的目标就是能够在移动端也能进行部署深度学习框架,既然Android也能运行TensorFlow 为何不尝试一下,这是程序员们的通病,干就完了。
本次开发环境为TensorFlow 2.1+python 3.7+Android studio 3.6.1
这里我只是简单的说一下,毕竟今天的目标不是搭建环境,而是如何在Android上部署TensorFlow。你直接下载anaconda版的python,用命令安装,有坑,不要用pip命令,有两方面有原因,
解决办法:采用conda install python 这个命令。如何你对配置环境确实毫无头绪,你可以参考TensorFlow官网,或者百度一大堆教你如何搭建环境的。
今天我用了一个很简单的例子,用TensorFlow 拟合一个函数,y=ax+b,给出x,y,通过TensorFlow 算出a,b;看一下我拟合效果
其实很简单,代码如下:
import tensorflow as tf
# 创建一个简单的 Keras 模型。
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=500)
print(model.predict([1, 3, 7]))
记住([1, 3, 7])得到的结果,我们会在Android也运行这组数据,看结果是否一样。
上面是用Keras写例子,既然python 的TensorFlow 已经写好了,那如何在Android上用呢,就需要用到转换,将python代码转成Android 可以用的。
export_dir = 'saved_model/test'
tf.saved_model.save(model, export_dir)
#转换模型。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)
这个tf.saved_model.save(model, export_dir) 可能会报错,无法创建目录,你可以手动来创建目录。
在对应的目录下,找到一个.tflite的文件,这个文件就是我们在Android要调用。
Android 环境需要哪些依赖?
TensorFlow lite文件放在哪?
如何调用?
数据如何输入?
结果在哪?
下面我都会一一说明。
在build.gradle中依赖 implementation ‘org.tensorflow:tensorflow-lite:0.0.0-nightly’,这个必不可少的。
添加arm支持
defaultConfig {
……
ndk {
abiFilters 'armeabi-v7a', 'arm64-v8a'
}
}
android{
……
aaptOptions {
noCompress "tflite" //表示不让aapt压缩的文件后缀
}
}
public class TFLiteLoader { private static Context mContext; Interpreter mInterpreter; private static TFLiteLoader instance; public static TFLiteLoader newInstance(Context context) { mContext = context; if (instance == null) { instance = new TFLiteLoader(); } return instance; } Interpreter get() { try { if (Objects.isNull(mInterpreter)) mInterpreter = new Interpreter(loadModelFile(mContext)); } catch (IOException e) { e.printStackTrace(); } return mInterpreter; } // 获取文件 private MappedByteBuffer loadModelFile(Context context) throws IOException { AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite"); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } }
输入输入和输出:
float[][] input = new float[][]{{1, 3, 7}};
float[][] output = new float[3][1];
TFLiteLoader.newInstance(getApplicationContext()).get().run(input, output);
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 1; j++) {
Log.i(TAG, output[i][j] + "");
}
}
运行结果如下:
这个结果和python版的一样的,到此我们算是成功在Android上部署了TensorFlow lite。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。