赞
踩
1.安装tensorflow(version>=1.4.0)
2.部署tensorflow models
- 在这里下载
- 解压并安装
- 解压后重命名为models复制到tensorflow/目录下
- 在linux下
- 进入tensorflow/models/research/目录,运行protoc object_detection/protos/*.proto --python_out=.
- 在~/.bashrc file.中添加slim和models/research路径
export PYTHONPATH=$PYTHONPATH:/path/to/slim:/path/to/research
- 在windows下
- 下载protoc-3.3.0-win32.zip(version==3.3,已知3.5版本会报错)
- 解压后将protoc.exe放入C:\Windows下
- 在tensorflow/models/research/打开powershell,运行protoc object_detection/protos/*.proto --python_out=.
3.训练数据准备(标记分类的图片)
- 安装labelImg 用来手动标注图片 ,图片需要是png或者jpg格式
- 标注信息会被保存为xml文件,使用 这个脚本 将所有xml文件转换为一个csv文件(xml文件路径识别在29行,根据情况自己修改)
- 把生成的csv文件分成训练集和测试集
4.生成TFRecord文件
- 使用 这个脚本 将两个csv文件生成出两个TFRecord文件(训练自己的模型,必须使用TFRecord格式文件。图片路径识别在86行,根据情况自己修改)
5.创建label map文件
id需要从1开始,class-N便是自己需要识别的物体类别名,文件后缀为.pbtxt
item{
id:1
name: 'class-1'
}
item{
id:2
name: 'class-2'
}
6.下载模型并配置文件
- 下载一个模型(文件后缀.tar.gz)
- 修改对应的训练pipline配置文件
- 查找文件中的PATH_TO_BE_CONFIGURED字段,并做相应修改
- num_classes 改为你模型中包含类别的数量
- fine_tune_checkpoint 解压.tar.gz文件后的路径 + /model.ckpt
- from_detection_checkpoint:true
- train_input_reader
- input_path 由train.csv生成的record格式训练数据
- label_map_path 第5步创建的pbtxt文件路径
- eval_input_reader
- input_path 由test.csv生成的record格式训练数据
- label_map_path 第5步创建的pbtxt文件路径
7. 训练模型
- 进入tensorflow/models/research/目录,运行
python object_detection/train.py --logtostderr --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} //第六步中修改的pipline配置文件路径// --train_dir=${PATH_TO_TRAIN_DIR} //生成的模型保存路径//
8.导出模型
- 在第7步中,--train_dir指向的路径中会生成一系列训练中自动保存的checkpoint,一个checkpoint由三个文件组成,后缀分别是.data-00000-of-00001 .index和.meta,任然在第7步的路径中,运行
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path ${PIPELINE_CONFIG_PATH} //第六步中修改的pipline配置文件路径\--trained_checkpoint_prefix ${TRAIN_PATH} //上述的一个checkpoint,例如model.ckpt-112254 \ --output_directory ${OUTPUT_PATH} //输出模型文件的路径//
9.使用新模型识别图片
调用predict.py
首先导入包
import time import cv2 import numpy as np import tensorflow as tf import pandas as pd import math import os from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util
然后定义类和函数
class TOD(object): def __init__(self): self.PATH_TO_CKPT = r'D:/xiangchuang/new_train_model/result/frozen_inference_graph.pb' self.PATH_TO_LABELS = r'D:/xiangchuang/pig.pbtxt' self.NUM_CLASSES = 1 self.detection_graph = self._load_model() self.category_index = self._load_label_map() def _load_model(self): global detection_graph detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') return detection_graph def _load_label_map(self): label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True) category_index = label_map_util.create_category_index(categories) return category_index def detect(self, image): image_np_expanded = np.expand_dims(image, axis=0) image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0') boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0') scores = self.detection_graph.get_tensor_by_name('detection_scores:0') classes = self.detection_graph.get_tensor_by_name('detection_classes:0') num_detections = self.detection_graph.get_tensor_by_name('num_detections:0') # Actual detection. (boxes, scores, classes, num_detections) = sess.run( [boxes, scores, classes, num_detections], feed_dict={image_tensor: image_np_expanded}) # Visualization of the results of a detection. vis_util.visualize_boxes_and_labels_on_image_array( image, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), self.category_index, use_normalized_coordinates=True, line_thickness=8) cv2.namedWindow("detection", cv2.WINDOW_NORMAL) cv2.imshow("detection", image) cv2.waitKey(1)
最后执行
if __name__ == '__main__': detector = TOD() with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: cap = cv2.VideoCapture(r'Your Vedio Path') n = 1 success = True while (success) : success, frame = cap.read() t1=time.clock() print('正在预测第%d张' % n) n = n + 1 if success == True: detector.detect(frame) t2=time.clock() t = t2-t1 print('cost time %f s'%t) cv2.destroyAllWindows()
即可以实现基于视频的目标目标检测
参考文档
https://gist.github.com/douglasrizzo/c70e186678f126f1b9005ca83d8bd2ce
https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。