赞
踩
很多博客教程都需要安装bazel编译工具,并且编译tensorflow的源码,过程繁琐。本篇博文教你如何用几行代码就实现.pb到.tflite文件转换。
.pb到.tflite文件转换代码先贴上来,随后做详细讲解。
- import tensorflow as tf
-
- in_path = r'.\yolov3_coco.pb'
- out_path = r'.\yolov3_coco.tflite'
-
- input_arrays = ["input/input_data"]
- input_shapes = {"input/input_data" :[1, 416, 416, 3]}
- output_arrays = ["pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
-
- converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path, input_arrays, output_arrays, input_shapes)
- tflite_model = converter.convert()
- open("myModel.tflite", "wb").write(tflite_model)
这里需要修改的参数有:in_path, out_path, input_arrays, output_arrays, input_shapes.
in_path:输入的.pb文件路径 out_path:输出的.tflite文件路径
input_arrays:输入的节点名称 output_arrays:输出的节点名称
input_shapes:输入的节点形状
设置路径倒是难不倒大家,但是节点名称和节点形状的参数哪里找呢?以生成yolov3的.pb文件为例说明。
yolov3的代码中,.pb文件是通过convert_weight.py文件和freeze_graph.py文件。俗话说得好,解铃还须系铃人,参数就可以从中找到。
freeze_graph.py文件可以找到input_arrays以及output_arrays参数
convert_weight.py文件找到了input_shapes
但你复制代码后,你的tf.contrib.lite.TFLiteConverter这条语句会报错,因为这条语句是针对tensorflow 1.12极其以上版本。
不同tensorflow版本对应的转换语句如下。
在tensorflow 2.0版本中,转化语句变成了命令行, tflite_convert。在终端执行如下命令,即可查看其具体使用方法。
tflite_convert -h
我在Python Terminal当中的执行结果。
如果本篇文章有任何可以改进的地方,欢迎大家在下方留言:)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。