赞
踩
@Tensorflow 1.13训练模型.pb文件转换成Tensorflowlite可以使用的.tflite文件过程记录
之前一直通过1.13版本的TensorflowGpu训练模型,使用范围局限在电脑端(例如opencv调用模型等等)。最近的一个项目需要在移动端部署,将训练好的.pb模型可以成功移植到安卓移动端,但是出现了一个老生常谈的问题,就是无法迅速连续识别,这主要是因为移动端和PC端硬件的差异,为了解决这一问题,决定投入Tensorflowlite的怀抱,实现在移动端的迅捷目标检测。
环境介绍:
Ubuntu 16.04
Tensorflow 1.13.1(含Gpu)
移动端 Honor V20
Android Studio 3.5.3
算法 MobileNet-ssd-V1
tensorflow训练开始后,会随时间推移生成不同训练步数的记录文件,如下:
如果不考虑后续的Tflite转换,那么只需要调用object_detection的export_inference_graph.py,输入以下类似命令来生成.pb文件即可,文件使用方法不赘述。
python export_inference_graph.py --input_type image_tensor --pipeline_co
nfig_path training/ssd_mobilenet_v1_XXX.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory detection
注意。如果后续要设计tflite转换,那么需要调用的文件是object_detection下的export_tflite_ssd_graph.py,命令与上类似:
python export_tflite_ssd_graph.py --input_type image_tensor --pipeline_co
nfig_path training/ssd_mobilenet_v1_XXX.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory detection
执行export_tflite_ssd_graph.py后,输出文件夹内容大致如下:
包含tflite_graph.pb和tflite_graph.pbtxt两个文件。这就是我们需要使用的.pb文件
注意!很多文章都有详细的Tensorflow Bazel配置过程详解,我们的配置过程类似,但是Bazel build过程,我们只需要以下一个步骤即可,
bazel build tensorflow/tools/graph_transforms:summarize_graph
build完成后,可以通过命令获取.pb模型输入输出节点array名称和相关矩阵参数,在下面的pb_to_tflite.py程序中填写使用。
注意!!!
bazel build tensorflow/tools/graph_transforms:summarize_graph此处有个BUG,输出的信息中,output_array我这里只显示一个TFLite_Detection_PostProcess,
其实正确的应该是
‘TFLite_Detection_PostProcess’,‘TFLite_Detection_PostProcess:1’,‘TFLite_Detection_PostProcess:2’,'TFLite_Detection_PostProcess:3’这四个,
分别代表的含义是
detection_boxes, detection_classes, detection_scores, and num_detections,
在后面的pb_to_tflite.py中也要写这四个才能保证.tflite模型在移动端的正常使用!!!
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tflite_graph.pb
至于其他文章经常提及的两个步骤:
bazel build tensorflow/lite/toco:toco
bazel build tensorflow/python/tools:freeze_graph
其实并不是必要的,原因:
首先
bazel build tensorflow/lite/toco:toco
的一个目的是,在后续转换tflite文件中,可通过
bazel run --config=opt tensorflow/lite/toco:toco -- \
--input_file=$OUTPUT_DIR/tflite_graph.pb \
--output_file=$OUTPUT_DIR/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
生成.tflite文件,但是我在使用过程,总是出现各种莫名其妙的错误,可能是我太菜了。。,所以这一步我用一个pb_to_tflite.py程序替代,如下
# -*- coding:utf-8 -*- ##python 1 import tensorflow as tf in_path = "tflite_graph.pb" #out_path = "tflite_graph.tflite" # out_path = "./model/quantize_frozen_graph.tflite" # 模型输入节点 input_tensor_name = ["normalized_input_image_tensor"] input_tensor_shape = {"normalized_input_image_tensor":[1,300,300,3]} # 模型输出节点 classes_tensor_name = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'] converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path, input_tensor_name, classes_tensor_name, input_tensor_shape) converter.allow_custom_ops=True #converter.post_training_quantize = True tflite_model = converter.convert() open("4output_detect.tflite", "wb").write(tflite_model)
生成一个自定义名称的.tflite文件。
而
bazel build tensorflow/python/tools:freeze_graph
在前面的export_tflite_ssd_graph.py执行后,其实已经freeze模型了,所以这行build其实也没用上。
converter.allow_custom_ops=True
这一行很重要,目的是保存一个在tflite中一些无法转换的原模型参数,不填加的话,十有八九回报错。
#converter.post_training_quantize = True
这一行的目的是决定是否输出量化的tflite模型
注意!量化后精度必然会有一定程度的降低,大小将缩小至1/4。
将第二步中的tflite文件放在android/app/src/main/assets中,并在同一目录新建一个txt文件存放物体标签,格式类似于:
???
label1
label2
...
???
???
在官方提供的项目的gradle中,comment out这一句
// apply from:'download_model.gradle'
避免下载官方tflite模型覆盖自己的模型。
以下三行是调用移植是否成功的关键,
private static final boolean TF_OD_API_IS_QUANTIZED = true;
private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/labels_list.txt";
第一行根据自己的程序是否量化自行填写,量化写true,反之false
第二行和第三行则填入自己的文件,随后Run至手机即可
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。