当前位置:   article > 正文

object detection api 之 maskrcnn 训练自己的数据集_maskrcnn网络训练数据集

maskrcnn网络训练数据集

使用 labelme 标注软件标注数据,每张图片会生成一个对应的 json 文件。先将数据划分为训练集和测试集,再把每个集合中的 jpg 图片和 json 标注文件分别放到各自的目录中。

1,运行create_tf_record.py

  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Sun Aug 26 10:57:09 2018
  5. @author: shirhe-lyh
  6. """
  7. """Convert raw dataset to TFRecord for object_detection.
  8. Please note that this tool only applies to labelme's annotations(json file).
  9. Example usage:
  10. python3 create_tf_record.py \
  11. --images_dir=your absolute path to read images. \
  12. --annotations_json_dir=your path to annotation json files. \
  13. --label_map_path=your path to label_map.pbtxt \
  14. --output_path=your path to write .record.
  15. """
  16. import cv2
  17. import glob
  18. import hashlib
  19. import io
  20. import json
  21. import numpy as np
  22. import os
  23. import PIL.Image
  24. import tensorflow as tf
  25. import read_pbtxt_file
  26. flags = tf.app.flags
  27. flags.DEFINE_string('images_dir', None, 'Path to images directory.')
  28. flags.DEFINE_string('annotations_json_dir', 'datasets/annotations',
  29. 'Path to annotations directory.')
  30. flags.DEFINE_string('label_map_path', None, 'Path to label map proto.')
  31. flags.DEFINE_string('output_path', None, 'Path to the output tfrecord.')
  32. FLAGS = flags.FLAGS
  33. def int64_feature(value):
  34. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  35. def int64_list_feature(value):
  36. return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  37. def bytes_feature(value):
  38. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  39. def bytes_list_feature(value):
  40. return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
  41. def float_list_feature(value):
  42. return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  43. def create_tf_example(annotation_dict, label_map_dict=None):
  44. """Converts image and annotations to a tf.Example proto.
  45. Args:
  46. annotation_dict: A dictionary containing the following keys:
  47. ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
  48. 'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
  49. 'class_names'].
  50. label_map_dict: A dictionary maping class_names to indices.
  51. Returns:
  52. example: The converted tf.Example.
  53. Raises:
  54. ValueError: If label_map_dict is None or is not containing a class_name.
  55. """
  56. if annotation_dict is None:
  57. return None
  58. if label_map_dict is None:
  59. raise ValueError('`label_map_dict` is None')
  60. height = annotation_dict.get('height', None)
  61. width = annotation_dict.get('width', None)
  62. filename = annotation_dict.get('filename', None)
  63. sha256_key = annotation_dict.get('sha256_key', None)
  64. encoded_jpg = annotation_dict.get('encoded_jpg', None)
  65. image_format = annotation_dict.get('format', None)
  66. xmins = annotation_dict.get('xmins', None)
  67. xmaxs = annotation_dict.get('xmaxs', None)
  68. ymins = annotation_dict.get('ymins', None)
  69. ymaxs = annotation_dict.get('ymaxs', None)
  70. masks = annotation_dict.get('masks', None)
  71. class_names = annotation_dict.get('class_names', None)
  72. print("class_names:",class_names)
  73. labels = []
  74. for class_name in class_names:
  75. label = label_map_dict.get(class_name, 'None')
  76. print("label:",label)
  77. if label is None:
  78. raise ValueError('`label_map_dict` is not containing {}.'.format(
  79. class_name))
  80. labels.append(label)
  81. encoded_masks = []
  82. for mask in masks:
  83. pil_image = PIL.Image.fromarray(mask.astype(np.uint8))
  84. output_io = io.BytesIO()
  85. pil_image.save(output_io, format='PNG')
  86. encoded_masks.append(output_io.getvalue())
  87. feature_dict = {
  88. 'image/height': int64_feature(height),
  89. 'image/width': int64_feature(width),
  90. 'image/filename': bytes_feature(filename.encode('utf8')),
  91. 'image/source_id': bytes_feature(filename.encode('utf8')),
  92. 'image/key/sha256': bytes_feature(sha256_key.encode('utf8')),
  93. 'image/encoded': bytes_feature(encoded_jpg),
  94. 'image/format': bytes_feature(image_format.encode('utf8')),
  95. 'image/object/bbox/xmin': float_list_feature(xmins),
  96. 'image/object/bbox/xmax': float_list_feature(xmaxs),
  97. 'image/object/bbox/ymin': float_list_feature(ymins),
  98. 'image/object/bbox/ymax': float_list_feature(ymaxs),
  99. 'image/object/mask': bytes_list_feature(encoded_masks),
  100. 'image/object/class/label': int64_list_feature(labels)}
  101. example = tf.train.Example(features=tf.train.Features(
  102. feature=feature_dict))
  103. return example
  104. def _get_annotation_dict(images_dir, annotation_json_path):
  105. """Get boundingboxes and masks.
  106. Args:
  107. images_dir: Path to images directory.
  108. annotation_json_path: Path to annotated json file corresponding to
  109. the image. The json file annotated by labelme with keys:
  110. ['lineColor', 'imageData', 'fillColor', 'imagePath', 'shapes',
  111. 'flags'].
  112. Returns:
  113. annotation_dict: A dictionary containing the following keys:
  114. ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
  115. 'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
  116. 'class_names'].
  117. #
  118. # Raises:
  119. # ValueError: If images_dir or annotation_json_path is not exist.
  120. """
  121. # if not os.path.exists(images_dir):
  122. # raise ValueError('`images_dir` is not exist.')
  123. #
  124. # if not os.path.exists(annotation_json_path):
  125. # raise ValueError('`annotation_json_path` is not exist.')
  126. if (not os.path.exists(images_dir) or
  127. not os.path.exists(annotation_json_path)):
  128. return None
  129. with open(annotation_json_path, 'r') as f:
  130. json_text = json.load(f)
  131. shapes = json_text.get('shapes', None)
  132. if shapes is None:
  133. return None
  134. image_relative_path = json_text.get('imagePath', None)
  135. print("imagePath",image_relative_path)
  136. if image_relative_path is None:
  137. return None
  138. image_name = image_relative_path.split('/')[-1]
  139. image_path = os.path.join(images_dir, image_name)
  140. image_format = image_name.split('.')[-1].replace('jpg', 'jpeg')
  141. if not os.path.exists(image_path):
  142. return None
  143. with tf.gfile.GFile(image_path, 'rb') as fid:
  144. encoded_jpg = fid.read()
  145. image = cv2.imread(image_path)
  146. height = image.shape[0]
  147. width = image.shape[1]
  148. key = hashlib.sha256(encoded_jpg).hexdigest()
  149. xmins = []
  150. xmaxs = []
  151. ymins = []
  152. ymaxs = []
  153. masks = []
  154. class_names = []
  155. hole_polygons = []
  156. for mark in shapes:
  157. class_name = mark.get('label')
  158. class_names.append(class_name)
  159. polygon = mark.get('points')
  160. polygon = np.array(polygon,dtype=np.int32)
  161. if class_name == 'hole':
  162. hole_polygons.append(polygon)
  163. else:
  164. mask = np.zeros(image.shape[:2])
  165. cv2.fillPoly(mask, [polygon], 1)
  166. masks.append(mask)
  167. # Boundingbox
  168. x = polygon[:, 0]
  169. y = polygon[:, 1]
  170. xmin = np.min(x)
  171. xmax = np.max(x)
  172. ymin = np.min(y)
  173. ymax = np.max(y)
  174. xmins.append(float(xmin) / width)
  175. xmaxs.append(float(xmax) / width)
  176. ymins.append(float(ymin) / height)
  177. ymaxs.append(float(ymax) / height)
  178. # Remove holes in mask
  179. for mask in masks:
  180. mask = cv2.fillPoly(mask, hole_polygons, 0)
  181. annotation_dict = {'height': height,
  182. 'width': width,
  183. 'filename': image_name,
  184. 'sha256_key': key,
  185. 'encoded_jpg': encoded_jpg,
  186. 'format': image_format,
  187. 'xmins': xmins,
  188. 'xmaxs': xmaxs,
  189. 'ymins': ymins,
  190. 'ymaxs': ymaxs,
  191. 'masks': masks,
  192. 'class_names': class_names}
  193. return annotation_dict
  194. def main(_):
  195. if not os.path.exists(FLAGS.images_dir):
  196. raise ValueError('`images_dir` is not exist.')
  197. if not os.path.exists(FLAGS.annotations_json_dir):
  198. raise ValueError('`annotations_json_dir` is not exist.')
  199. if not os.path.exists(FLAGS.label_map_path):
  200. raise ValueError('`label_map_path` is not exist.')
  201. label_map = read_pbtxt_file.get_label_map_dict(FLAGS.label_map_path)
  202. writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  203. num_annotations_skiped = 0
  204. annotations_json_path = os.path.join(FLAGS.annotations_json_dir, '*.json')
  205. for i, annotation_file in enumerate(glob.glob(annotations_json_path)):
  206. if i % 100 == 0:
  207. print('On image %d', i)
  208. annotation_dict = _get_annotation_dict(FLAGS.images_dir, annotation_file)
  209. if annotation_dict is None:
  210. num_annotations_skiped += 1
  211. continue
  212. tf_example = create_tf_example(annotation_dict, label_map)
  213. writer.write(tf_example.SerializeToString())
  214. print('Successfully created TFRecord to {}.'.format(FLAGS.output_path))
  215. if __name__ == '__main__':
  216. tf.app.run()

将上述脚本运行两次(分别传入训练集和测试集对应的图片目录与 json 目录),生成 train.tfrecord 和 test.tfrecord

之后和faster_rcnn训练方式一样

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/812775
推荐阅读
相关标签
  

闽ICP备14008679号