
TensorFlow Study Notes: Calling a Model from ROS for Object Detection


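Environment: Ubuntu 16.04 + TensorFlow-cpu 1.6.0 + ROS Kinetic + OpenCV 3.3.1

Prerequisites:

  1. The Object Detection API is configured
  2. OpenCV is configured

Once the model has been trained, the next step is to put it to use. Here a ROS node calls the model through the Object Detection API to recognize target objects.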

 

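I. Loading the Model

The model and label-map paths are set as shown below. Remember to set the number of object classes to match your label map.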

    # Excerpt from __init__: load the frozen graph and the label map
    rospy.loginfo("begin initialization...")
    self.PATH_TO_CKPT = '../frozen_inference_graph.pb'
    self.PATH_TO_LABELS = '../bottel.pbtxt'
    self.NUM_CLASSES = 2
    self.detection_graph = self._load_model()
    self.category_index = self._load_label_map()
    # Input and output tensors of the exported detection graph
    self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
    self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
    self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
    self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
    self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
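
To see why NUM_CLASSES matters: the label-map helpers keep only entries whose id lies in the range 1..max_num_classes, so a value that is too small silently drops classes. The sketch below is a simplified stand-in for that filtering, not the library code. The function name build_category_index and the class names are made up for illustration.

```python
def build_category_index(label_items, max_num_classes):
    """Simplified stand-in: keep label-map entries with 0 < id <= max_num_classes."""
    category_index = {}
    for item in label_items:
        if 0 < item["id"] <= max_num_classes:
            category_index[item["id"]] = {"id": item["id"], "name": item["name"]}
    return category_index


# Hypothetical label map with three entries (names are assumptions).
label_items = [
    {"id": 1, "name": "bottle"},
    {"id": 2, "name": "cup"},
    {"id": 3, "name": "box"},
]

# With max_num_classes=2 the entry with id 3 is ignored.
category_index = build_category_index(label_items, max_num_classes=2)
print(category_index)
```

The resulting dictionary is keyed by class id, which is the shape the visualization step expects when it looks up a label for each detected class.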

 

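II. Data Processing

The image data has to be processed before the model can be called. The flow is as follows.

  1. The camera image is published on ROS as a ROS Image message. CvBridge converts it from the ROS Image message format to the Mat format.
  2. The image is preprocessed with OpenCV and converted to a NumPy array, which is then passed to the Object Detection API for recognition.
  3. After the objects in the image have been recognized, the results are available as arrays. The program draws a colored rectangle around each recognized object, writes its label and score above the box, converts the image back to a ROS Image message and publishes it to the output topic (the code below publishes only the annotated image).

Implementation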

    # Detect objects in one image message and publish the annotated image
    def imgprogress(self, image_msg):
        with self.detection_graph.as_default():
            # Note: a new session is created for every incoming image.
            with tf.Session(graph=self.detection_graph) as sess:
                # Convert the ROS image message to an RGB array (H x W x 3, uint8)
                cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "rgb8")
                # Copy into a writable array so the boxes can be drawn in place
                image_np = np.array(cv_image, dtype=np.uint8)
                # The model expects a batch of shape [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)
                # Run the actual detection
                (boxes, scores, classes, num_detections) = sess.run(
                    [self.boxes, self.scores, self.classes, self.num_detections],
                    feed_dict={self.image_tensor: image_np_expanded})
                # Draw boxes, labels and scores onto image_np
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image_np,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)
                # Publish the annotated image
                ROSImage_pro = self._cv_bridge.cv2_to_imgmsg(image_np, encoding="rgb8")
                self._pub.publish(ROSImage_pro)
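
If the raw results are needed rather than just the annotated image, the arrays returned by sess.run can be post-processed directly. The following is a minimal sketch, assuming the usual output layout of the Object Detection API: a leading batch dimension and normalized boxes in [ymin, xmin, ymax, xmax] order. The helper name filter_detections, the 0.5 threshold and the fake input values are assumptions for illustration.

```python
import numpy as np


def filter_detections(boxes, scores, classes, im_width, im_height, min_score=0.5):
    """Drop low-score detections and convert normalized boxes to pixel coordinates."""
    boxes = np.squeeze(boxes, axis=0)      # [N, 4]
    scores = np.squeeze(scores, axis=0)    # [N]
    classes = np.squeeze(classes, axis=0).astype(np.int32)  # [N]
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        results.append({
            "class_id": int(cls),
            "score": float(score),
            # Pixel box as (left, top, right, bottom)
            "box_px": (int(round(xmin * im_width)), int(round(ymin * im_height)),
                       int(round(xmax * im_width)), int(round(ymax * im_height))),
        })
    return results


# Fake model outputs with a batch dimension of 1 (values are made up).
boxes = np.array([[[0.25, 0.125, 0.5, 0.75], [0.0, 0.0, 1.0, 1.0]]], dtype=np.float32)
scores = np.array([[0.9, 0.3]], dtype=np.float32)
classes = np.array([[1.0, 2.0]], dtype=np.float32)

detections = filter_detections(boxes, scores, classes, im_width=640, im_height=480)
print(detections)
```

Only the first fake detection passes the threshold. Its class id can be looked up in self.category_index to get a readable label.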

 

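III. Triggering Detection

Object detection through the Object Detection API consumes a lot of resources, so running it continuously on a live stream is very laggy. Detection is therefore triggered on demand. The program creates a subscriber, self._sub, that receives the images to be recognized. Whenever detection is needed, publishing an image to image_topic triggers the program, and the result is published through self._pub to the object_detection topic.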

    # Subscribe to the trigger topic: each received image starts one detection
    self._sub = rospy.Subscriber(image_topic, ROSImage, self.imgprogress, queue_size=10)
    # Publish the annotated image
    self._pub = rospy.Publisher('object_detection', ROSImage, queue_size=1)
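
One possible way to feed image_topic from a live camera without overloading the detector is to forward at most one frame per time interval. This is not part of the original program. The sketch below shows only the rate-limiting logic in plain Python; the class name FrameThrottle and the 2-second interval are assumptions.

```python
class FrameThrottle(object):
    """Let a frame through only if min_interval seconds have passed since the last one."""

    def __init__(self, min_interval):
        self.min_interval = min_interval
        self._last = None  # time of the last forwarded frame

    def allow(self, now):
        if self._last is None or now - self._last >= self.min_interval:
            self._last = now
            return True
        return False


# Simulated frame arrival times in seconds.
throttle = FrameThrottle(min_interval=2.0)
decisions = [throttle.allow(t) for t in (0.0, 0.5, 1.9, 2.0, 3.0, 4.5)]
print(decisions)
```

In a ROS node, a camera callback could call throttle.allow(rospy.get_time()) and republish the message to image_topic only when it returns True.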

Complete program

    #!/usr/bin/env python
    import numpy as np
    import rospy
    import tensorflow as tf
    from cv_bridge import CvBridge
    from sensor_msgs.msg import Image as ROSImage

    # The object_detection package must be importable (on PYTHONPATH).
    from object_detection.utils import label_map_util
    from object_detection.utils import visualization_utils as vis_util


    class ObjectDetectionDemo(object):
        def __init__(self):
            rospy.init_node('tfobject')
            # Register the shutdown hook
            rospy.on_shutdown(self.shutdown)
            # Camera topic, not used in this demo: rospy.get_param("~image_topic", "")
            camera_topic = "/camera/rgb/image_raw"
            image_topic = "/image/rgb/object"
            self._cv_bridge = CvBridge()
            # Load the frozen graph and the label map
            rospy.loginfo("begin initialization...")
            self.PATH_TO_CKPT = '../frozen_inference_graph.pb'
            self.PATH_TO_LABELS = '../bottel.pbtxt'
            self.NUM_CLASSES = 2
            self.detection_graph = self._load_model()
            self.category_index = self._load_label_map()
            # Input and output tensors of the exported detection graph
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
            # Subscribe to the trigger topic: each received image starts one detection
            self._sub = rospy.Subscriber(image_topic, ROSImage, self.imgprogress, queue_size=10)
            # Publish the annotated image
            self._pub = rospy.Publisher('object_detection', ROSImage, queue_size=1)
            rospy.loginfo("initialization has finished...")

        def _load_model(self):
            """Load the frozen inference graph into a new tf.Graph."""
            detection_graph = tf.Graph()
            with detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)
                    tf.import_graph_def(od_graph_def, name='')
            return detection_graph

        def _load_label_map(self):
            """Build the class id -> category lookup from the label map file."""
            label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
            categories = label_map_util.convert_label_map_to_categories(
                label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
            category_index = label_map_util.create_category_index(categories)
            return category_index

        # Detect objects in one image message and publish the annotated image
        def imgprogress(self, image_msg):
            with self.detection_graph.as_default():
                # Note: a new session is created for every incoming image.
                with tf.Session(graph=self.detection_graph) as sess:
                    # Convert the ROS image message to an RGB array (H x W x 3, uint8)
                    cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "rgb8")
                    # Copy into a writable array so the boxes can be drawn in place
                    image_np = np.array(cv_image, dtype=np.uint8)
                    # The model expects a batch of shape [1, None, None, 3]
                    image_np_expanded = np.expand_dims(image_np, axis=0)
                    # Run the actual detection
                    (boxes, scores, classes, num_detections) = sess.run(
                        [self.boxes, self.scores, self.classes, self.num_detections],
                        feed_dict={self.image_tensor: image_np_expanded})
                    # Draw boxes, labels and scores onto image_np
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image_np,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        self.category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8)
                    # Publish the annotated image
                    ROSImage_pro = self._cv_bridge.cv2_to_imgmsg(image_np, encoding="rgb8")
                    self._pub.publish(ROSImage_pro)

        # Called when the node shuts down
        def shutdown(self):
            rospy.loginfo("Stopping the tensorflow object detection...")
            rospy.sleep(1)


    if __name__ == '__main__':
        try:
            ObjectDetectionDemo()
            rospy.spin()
        except rospy.ROSInterruptException:
            rospy.loginfo("RosTensorFlow_ObjectDetectionDemo has been interrupted.")

 
