当前位置:   article > 正文

Tensorflow 1.13训练模型.pb文件转换成Tensorflowlite可以使用的.tflite文件过程记录

Tensorflow 1.13训练模型.pb文件转换成Tensorflowlite可以使用的.tflite文件过程记录

@Tensorflow 1.13训练模型.pb文件转换成Tensorflowlite可以使用的.tflite文件过程记录

前言

之前一直通过1.13版本的TensorflowGpu训练模型,使用范围局限在电脑端(例如opencv调用模型等等)。最近的一个项目需要在移动端部署,将训练好的.pb模型可以成功移植到安卓移动端,但是出现了一个老生常谈的问题,就是无法迅速连续识别,这主要是因为移动端和PC端硬件的差异,为了解决这一问题,决定投入Tensorflowlite的怀抱,实现在移动端的迅捷目标检测。
环境介绍:
Ubuntu 16.04
Tensorflow 1.13.1(含Gpu)
移动端 Honor V20
Android Studio 3.5.3
算法 MobileNet-ssd-V1

一、.pb文件的生成

tensorflow训练开始后,会随时间推移生成不同训练步数的记录文件,如下:
在这里插入图片描述如果不考虑后续的Tflite转换,那么只需要调用object_detection的export_inference_graph.py,输入以下类似命令来生成.pb文件即可,文件使用方法不赘述。

python export_inference_graph.py --input_type image_tensor --pipeline_co
nfig_path training/ssd_mobilenet_v1_XXX.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory detection  
  • 1
  • 2

注意。如果后续要设计tflite转换,那么需要调用的文件是object_detection下的export_tflite_ssd_graph.py,命令与上类似:

python export_tflite_ssd_graph.py --input_type image_tensor --pipeline_co
nfig_path training/ssd_mobilenet_v1_XXX.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory detection   
  • 1
  • 2

执行export_tflite_ssd_graph.py后,输出文件夹内容大致如下:
在这里插入图片描述包含tflite_graph.pb和tflite_graph.pbtxt两个文件。这就是我们需要使用的.pb文件

二、.pb转换.tflite

(1)Bazel配置

注意!很多文章都有详细的Tensorflow Bazel配置过程详解,我们的配置过程类似,但是Bazel build过程,我们只需要以下一个步骤即可,

bazel build tensorflow/tools/graph_transforms:summarize_graph
  • 1

build完成后,可以通过命令获取.pb模型输入输出节点array名称和相关矩阵参数,在下面的pb_to_tflite.py程序中填写使用。
注意!!!
bazel build tensorflow/tools/graph_transforms:summarize_graph此处有个BUG,输出的信息中,output_array我这里只显示一个TFLite_Detection_PostProcess,
其实正确的应该是
‘TFLite_Detection_PostProcess’,‘TFLite_Detection_PostProcess:1’,‘TFLite_Detection_PostProcess:2’,'TFLite_Detection_PostProcess:3’这四个,
分别代表的含义是
detection_boxes, detection_classes, detection_scores, and num_detections,
在后面的pb_to_tflite.py中也要写这四个才能保证.tflite模型在移动端的正常使用!!!

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tflite_graph.pb 
  • 1

至于其他文章经常提及的两个步骤:

bazel build tensorflow/lite/toco:toco
bazel build tensorflow/python/tools:freeze_graph
  • 1
  • 2

其实并不是必要的,原因:

首先

bazel build tensorflow/lite/toco:toco
  • 1

的一个目的是,在后续转换tflite文件中,可通过

bazel run --config=opt tensorflow/lite/toco:toco -- \
 
--input_file=$OUTPUT_DIR/tflite_graph.pb \
 
--output_file=$OUTPUT_DIR/detect.tflite \
 
--input_shapes=1,300,300,3 \
 
--input_arrays=normalized_input_image_tensor \
 
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
 
--inference_type=FLOAT \
 
--allow_custom_ops
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

生成.tflite文件,但是我在使用过程,总是出现各种莫名其妙的错误,可能是我太菜了。。,所以这一步我用一个pb_to_tflite.py程序替代,如下

# -*- coding:utf-8 -*-
##python 1
import tensorflow as tf

in_path = "tflite_graph.pb"
#out_path = "tflite_graph.tflite"
# out_path = "./model/quantize_frozen_graph.tflite"

# 模型输入节点
input_tensor_name = ["normalized_input_image_tensor"]
input_tensor_shape = {"normalized_input_image_tensor":[1,300,300,3]}
# 模型输出节点
classes_tensor_name = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']

converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path,
                                            input_tensor_name, classes_tensor_name,
                                            input_tensor_shape)

converter.allow_custom_ops=True
#converter.post_training_quantize = True
tflite_model = converter.convert()

open("4output_detect.tflite", "wb").write(tflite_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

生成一个自定义名称的.tflite文件。

bazel build tensorflow/python/tools:freeze_graph
  • 1

在前面的export_tflite_ssd_graph.py执行后,其实已经freeze模型了,所以这行build其实也没用上。

(2)pb_to_tflite.py重要语句介绍

converter.allow_custom_ops=True
  • 1

这一行很重要,目的是保存一个在tflite中一些无法转换的原模型参数,不填加的话,十有八九回报错。

#converter.post_training_quantize = True
  • 1

这一行的目的是决定是否输出量化的tflite模型
注意!量化后精度必然会有一定程度的降低,大小将缩小至1/4。

三、Android Studio调用

将第二步中的tflite文件放在android/app/src/main/assets中,并在同一目录新建一个txt文件存放物体标签,格式类似于:

???
label1
label2
...
???
???
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在官方提供的项目的gradle中,comment out这一句

// apply from:'download_model.gradle'
  • 1

避免下载官方tflite模型覆盖自己的模型。
以下三行是调用移植是否成功的关键,

  private static final boolean TF_OD_API_IS_QUANTIZED = true;
  private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
  private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/labels_list.txt";
  • 1
  • 2
  • 3

第一行根据自己的程序是否量化自行填写,量化写true,反之false
第二行和第三行则填入自己的文件,随后Run至手机即可

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/154309
推荐阅读
相关标签
  

闽ICP备14008679号