当前位置:   article > 正文

基于TensorFlow Object Detection API使用ssd_mobilenet_v1训练自己的数据集并转换为tflite模型(亲测可用)_lite-model_ssd_mobilenet_v1_1_metadata_2.tflite

lite-model_ssd_mobilenet_v1_1_metadata_2.tflite

1. 环境

OS: Ubuntu 16.04 x64
Anaconda: 4.6.12
Python: 3.6.8
TensorFlow(GPU版): 1.13.1
OpenCV: 3.4.1

2. 基础环境配置

Anaconda 下载地址: Anaconda-4.6.12-Linux

本文中安装位置为 /usr/local/anaconda3

修改默认的 python 版本为 3.6

conda install python=3.6
  • 1

安装 OpenCV 3.4.1

conda install opencv=3.4.1
  • 1

倘若安装 TensorFlow 1.13.1(GPU版),首先需要安装合适的NVIDIA的驱动,重启电脑以后执行以下命令可以看到需要下载对应版本的cudatoolkit、cudnn、tensorflow-gpu等依赖包,使用conda安装不需要我们自己去配环境,输入y就可以直接安装。

conda install tensorflow-gpu==1.13.1
  • 1

倘若安装 TensorFlow 1.13.1(CPU版),直接执行以下命令。

conda install tensorflow=1.13.1
  • 1

3. TensorFlow Models

下载地址: Github - TensorFlow Models
下载后得到一个 models-master.zip 文件,解压后移动到 /usr/local/anaconda3/lib/python3.6/site-packages/tensorflow文件夹下,并重命名为 models

unzip models-master.zip
mv models /usr/local/anaconda3/lib/python3.6/site-packages/tensorflow
  • 1
  • 2

进入 models/research目录,并编译 protobuf

cd /usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research
protoc object_detection/protos/*.proto --python_out=.
  • 1
  • 2

安装 object_detection 库

python setup.py build
python setup.py install
  • 1
  • 2

设置 PYTHONPATH

export PYTHONPATH=$PYTHONPATH:/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research
export PYTHONPATH=$PYTHONPATH:/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/slim
  • 1
  • 2

直接执行以上命令只会在当前终端生效,使用vim ~/.bashrc进行编辑,将以上命令写入 ~/.bashrc并执行如下命令可以永久保存

source ~/.bashrc
  • 1

测试 object_detection 库是否安装成功

python object_detection/builders/model_builder_test.py
  • 1

4. 训练

下载 VOC 2012 数据集: VOCtrainval_11-May-2012.tar

object_detection目录下创建目录 ssd_model,并解压数据集至 object_detection/ssd_model

mkdir ssd_model/
cd ssd_model
tar xvf VOCtrainval_11-May-2012.tar
  • 1
  • 2
  • 3

解压后主要目录结构为ssd_model/VOCdevkit/VOC2012,其中Annotations文件夹下存放标注的xml文件,JPEGImages文件夹下存放对应的jpg图片,ImagesSets/Main文件夹下最主要的是train.txt,trainval.txt,val.txt三个txt文件,分别表示训练集、训练+验证集、验证集。

返回 research目录,倘若需要训练自己的数据,为了尽可能少地修改代码,可以将自己的数据集修改跟下载的VOC2012数据集一样的格式,即例如:jpg图片命名为2008_000001.jpg,对应的xml标注文件为2008_000001.xml,修改ImagesSets/Main文件夹下train.txt,trainval.txt,val.txt三个txt文件中内容为自己的训练、验证集,并修改object_detection/dataset_tools/create_pascal_tf_record.py中如下代码。

examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                  ‘aeroplane_’+ FLAGS.set + '.txt')
#将以上代码修改为如下代码,这里只需要用到Main下面的train.txt,trainval.txt,val.txt
#原来的代码是用了官方下面的XX_train.txt等文件
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main/'
                                  + FLAGS.set + '.txt')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

接着,./object_detection/data/pascal_label_map.pbtxt这个里面是根据我们开始标记的类型进行修改。最后执行以下命令。

如果只是使用下载的VOC2012数据集进行训练测试,则不需要上述自定义修改步骤,直接执行以下 train 和 val 脚本。

cd ../..
python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=./object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=train --output_path=./object_detection/ssd_model/pascal_train.record
python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=./object_detection/data/pascal_label_map.pbtxt --data_dir=./object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=val --output_path=./object_detection/ssd_model/pascal_val.record
  • 1
  • 2
  • 3

这两个脚本会在 ssd_model目录下生成 pascal_train.record 和 pascal_val.record 两个文件

复制配置文件,在此基础上修改,并训练数据

cp object_detection/data/pascal_label_map.pbtxt object_detection/ssd_model/
cp object_detection/samples/configs/ssd_mobilenet_v1_pets.config object_detection/ssd_model/
  • 1
  • 2

pascal_label_map.pbtxt 文件中保存了数据集中有哪些 label,此处应该修改为自己的类别。

将 ssd_mobilenet_v1_pets.config 中的 num_classes 改为 pascal_label_map.pbtxt 中列出的文件数量(VOC2012数据集是20,此处应该修改为自己所需识别的类数),并修改迭代次数 num_steps,并将配置文件末尾的路径按照如下格式修改

train_input_reader: {
  tf_record_input_reader {
    input_path: "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/ssd_model/pascal_train.record"
  }
  label_map_path: "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/ssd_model/pascal_val.record"
  }
  label_map_path: "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

下载 ssd_mobilenet 至 ssd_model目录下,解压并重命名为 ssd_mobilenet

ssd_mobilenet: ssd_mobilenet_v1_coco_11_06_2017.tar.gz

tar zxvf ssd_mobilenet_v1_coco_11_06_2017.tar.gz
mv ssd_mobilenet_v1_coco_11_06_2017 ssd_mobilenet
  • 1
  • 2

将 ssd_mobilenet_v1_pets.config 中 fine_tune_checkpoint 修改为如下格式的路径

fine_tune_checkpoint: "/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection
  • 1

使用 train.py 脚本训练模型

注意:脚本可能位于 object_detection/或 object_detection/legacy/目录下

这里位于 object_detection/legacy目录

python ./object_detection/legacy/train.py --train_dir ./object_detection/legacy/train/ --pipeline_config_path ./object_detection/ssd_model/ssd_mobilenet_v1_pets.config
  • 1

训练输出的checkpoint文件、日志文件、模型文件以及pipeline.config文件都将生成在object_detection/legacy/train目录下

使用命令行来到日志文件的上级路径下,输入如下命令

tensorboard --logdir ./train
  • 1

接着打开浏览器,输入http://127.0.0.1:6006,即可使用tensorboard查看模型保存的变量,如loss等的变化情况。

5. 转换tflite模型

首先,通过object_detection目录下的export_tflite_ssd_graph.py将训练后的模型导出所需要的文件。model.ckpt-80000中80000需要与训练迭代轮数匹配。

python object_detection/export_tflite_ssd_graph.py --pipeline_config_path=/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train/pipeline.config --trained_checkpoint_prefix=/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train/model.ckpt-80000 --output_directory=/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train --add_postprocessing_op=true
  • 1

运行后将在output_directory目录(这里设置为/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train)生成tflite_graph.pb 和tflite_graph.pbtxt两个文件。

接着,下载TensorFlow 1.13.1源码,解压为tensorflow-1.13.1文件夹

然后,安装bazel工具,编译转换工具:下载地址及各系统安装方法
安装完成后开始编译转换工具:进入TensorFlow目录,以实际工程目录地址为主

cd tensorflow-1.13.1/   
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
  • 1
  • 2
  • 3

最后,利用bazel生成tflite文件:

bazel run tensorflow/lite/toco:toco -- \
--input_file=/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train/tflite_graph.pb \
--output_file=/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/models/research/object_detection/legacy/train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

参考资料:

  1. 搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型
  2. 使用TensorFlow Lite将ssd_mobilenet移植至安卓客户端
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/154252?site
推荐阅读
相关标签
  

闽ICP备14008679号