
[TensorFlow] Reading Data with TFRecord (Part 1)


In a deep-learning project, the training/validation images have to be read in before the model can be trained. An earlier post, "TensorFlow 卷积神经网络 - 猫狗识别" (a TensorFlow CNN for cat-vs-dog classification), read the images with OpenCV. Loading images into matrices with OpenCV is perfectly adequate for training and works fine for small batches, but it becomes noticeably slow once the number of images is large.

For large projects with large image sets, TFRecord is the usual way to read data. TFRecord is a format supported natively by TensorFlow; it is fast, and it is recommended once the dataset exceeds roughly 10,000 images. A TFRecord file stores data in binary form and is well suited to reading large amounts of data serially. Its advantages are better use of memory and easier copying and moving of the data, which matches how the TensorFlow execution engine processes input. Converting a dataset to TFRecord usually means writing a small program that packs each sample into an Example object (defined by a protocol buffer), serializes it to a string, and writes it out with tf.python_io.TFRecordWriter.
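As a minimal sketch of that workflow (TF 1.x API; the image path and label value are placeholders, and the full multi-threaded script is given further below), a single sample can be written like this:

import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Read one image as raw JPEG bytes (placeholder path).
with tf.gfile.FastGFile('./flower_photos/roses/example.jpg', 'rb') as f:
    image_data = f.read()

# Pack the sample into an Example proto, serialize it, and write it out.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_feature(image_data),
    'image/class/label': _int64_feature(3),   # e.g. 3 = roses
}))
writer = tf.python_io.TFRecordWriter('./data/example.tfrecord')
writer.write(example.SerializeToString())
writer.close()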

Before building TFRecords, images of the same class are usually placed in the same folder. For example:

In the figure above, "flower_photos" is the root folder and contains 5 sub-folders: all rose images go into "roses", all sunflower images into "sunflowers", and so on. Organizing the data this way makes it easy to build the mapping between image path, image label (e.g. 1, 2, 3) and class name (e.g. daisy, dandelion, roses); a rough sketch of this mapping follows.
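As a simplified sketch of that mapping (plain Python with os and glob rather than the tf.gfile calls used in the full script below; assumes the flower_photos layout described above):

import os
from glob import glob

data_dir = './flower_photos'
# Sub-folder names double as class names: ['daisy', 'dandelion', 'roses', ...]
class_names = sorted(os.listdir(data_dir))

filenames, texts, labels = [], [], []
for label_index, text in enumerate(class_names, start=1):   # 0 is kept free for a background class
    matching_files = glob(os.path.join(data_dir, text, '*'))
    filenames.extend(matching_files)                    # image paths
    texts.extend([text] * len(matching_files))          # human-readable class names
    labels.extend([label_index] * len(matching_files))  # integer labels 1..5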

Images in the roses folder:

Implementation

Directory structure:

flower_label.txt:

This file lists the names of the 5 sub-folders under ./flower_photos, so the program knows which folders to read images from.

daisy
dandelion
roses
sunflowers
tulips

build_image_data.py:

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
# Build the TFRecords with multiple threads for speed; preprocessing and
# writing the data files are handled together.
import threading

import numpy as np
import tensorflow as tf

# Define string and integer flags.
# Only the training set is demonstrated here; a separate validation directory
# could be added. 'train_directory' is the flag name.
tf.app.flags.DEFINE_string('train_directory', './flower_photos/', 'Training data directory')
# Validation set; no separate validation directory is specified here.
tf.app.flags.DEFINE_string('validation_directory', './flower_photos/', 'Validation data directory')
# Output directory for the TFRecord files.
tf.app.flags.DEFINE_string('output_directory', './data/', 'Output data directory')
# Number of TFRecord files to generate. train_shards must be divisible by
# num_threads so the work can be split evenly.
tf.app.flags.DEFINE_integer('train_shards', 2, 'Number of shards in training TFRecord files.')
# As above; no validation set is built, only the training set.
tf.app.flags.DEFINE_integer('validation_shards', 0, 'Number of shards in validation TFRecord files.')
# Number of threads to launch.
tf.app.flags.DEFINE_integer('num_threads', 2, 'Number of threads to preprocess the images.')
# The labels file contains the list of valid labels, one per line, e.g.:
#   dog
#   cat
#   flower
# Each label is mapped to an integer corresponding to its line number
# (starting from 1; 0 is reserved as a background class).
# flower_label.txt matches the sub-folder names one-to-one.
tf.app.flags.DEFINE_string('labels_file', './flower_label.txt', 'Labels file')

# Parse the flags defined above.
FLAGS = tf.app.flags.FLAGS


def _int64_feature(value):
    """Wrapper for inserting int64 features into an Example proto.

    isinstance() checks whether an object is of a given type, similar to type().
    The difference: type() ignores inheritance, whereas isinstance() treats an
    instance of a subclass as an instance of the parent class as well.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into an Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _find_image_files(data_dir, labels_file):
    """Build a list of all image files and labels in the data set.

    :param data_dir: string, path to the root directory of images.
    :param labels_file: string, path to the labels file.
        The list of valid labels is held in this file. Assumes the file
        contains entries such as:
          dog
          cat
          flower
        where each line corresponds to a label. Each label in the file is
        mapped to an integer starting from 1 for the first line (0 is
        reserved as a background class).
    :return:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class name, e.g. 'dog'.
        labels: list of integers; each integer identifies the ground truth.
    """
    print('Data directory: %s.' % data_dir)
    # Read the contents of flower_label.txt.
    # tf.gfile.FastGFile(path, mode) opens a file; mode 'r' reads text (UTF-8),
    # 'rb' reads raw bytes.
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(labels_file, 'r').readlines()]
    labels = []
    filenames = []
    texts = []

    # Leave label index 0 empty as a background class.
    label_index = 1

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        try:
            # tf.gfile.Glob() returns the list of files matching the pattern.
            matching_files = tf.gfile.Glob(jpeg_file_path)
        except:
            print(jpeg_file_path)
            continue
        # Starting from 1, extend the integer labels for this image class.
        labels.extend([label_index] * len(matching_files))
        # Extend texts with the class name taken from flower_label.txt.
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)
        label_index += 1

    # Shuffle the ordering of all image files in order to guarantee random
    # ordering of the images with respect to labels in the saved TFRecord
    # files. Make the randomization repeatable.
    shuffled_index = list(range(len(filenames)))
    # Fix the seed so the shuffle is reproducible.
    random.seed(12345)
    random.shuffle(shuffled_index)
    # Reorder all three lists with the same permutation so they stay aligned.
    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    print('Found %d JPEG files across %d labels inside %s.' % (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities.

    Converts every image to JPEG-encoded RGB.
    """

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data, so all images
        # end up in the same format.
        self._png_data = tf.placeholder(dtype=tf.string)
        # Decode to 3 channels.
        image = tf.image.decode_png(self._png_data, channels=3)
        # Re-encode as RGB JPEG.
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg, feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _process_image(filename, coder):
    """Process a single image file.

    :param filename: string, path to an image file, e.g. '/path/to/example.JPG'.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :return:
        image_buffer: string, JPEG encoding of RGB image.
        height: integer, image height in pixels.
        width: integer, image width in pixels.
    """
    # Read the image file.
    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()

    # Convert any PNG to JPEG for consistency.
    if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that the image converted to RGB: shape is (height, width, channels).
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    # Make sure the image has three channels.
    assert image.shape[2] == 3

    return image_data, height, width


def _is_png(filename):
    """Determine if a file contains a PNG format image.

    :param filename: string, path of the image file.
    :return: boolean indicating if the image is a PNG.
    """
    return '.png' in filename


def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Build an Example proto for a single sample.

    :param filename: string, path to an image file, e.g. '/path/to/example.JPG'.
    :param image_buffer: string, JPEG encoding of RGB image.
    :param label: integer, identifier for the ground truth for the network.
    :param text: string, unique human-readable label, e.g. 'dog'.
    :param height: integer, image height in pixels.
    :param width: integer, image width in pixels.
    :return: Example proto.
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    # tf.compat.as_bytes() converts bytes or unicode to bytes, encoding text as UTF-8.
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
        'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
        'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
        'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))
    }))
    return example


def _process_image_files_batch(coder, thread_index, ranges, name, filenames, texts, labels, num_shards):
    """Processes and saves a list of images as TFRecord shards in one thread.

    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :param thread_index: integer, unique batch index within [0, len(ranges)).
    :param ranges: list of pairs of integers specifying the range of images each batch analyzes in parallel.
    :param name: string, unique identifier specifying the data set.
    :param filenames: list of strings; each string is a path to an image file.
    :param texts: list of strings; each string is human readable, e.g. 'dog'.
    :param labels: list of integers; each integer identifies the ground truth.
    :param num_shards: integer, number of shards for this data set.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128 and num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1], num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00001-of-00002'.
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]  # full path
            label = labels[i]        # integer label
            text = texts[i]          # folder (class) name

            image_buffer, height, width = _process_image(filename, coder)
            example = _convert_to_example(filename, image_buffer, label, text, height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                print('%s [thread %d]: Processed %d of %d images in thread batch.' % (
                    datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()
        print('%s [thread %d]: Wrote %d images to %s' % (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()
        shard_counter = 0

    print('%s [thread %d]: Wrote %d images to %d shards.' % (
        datetime.now(), thread_index, counter, num_files_in_thread))
    sys.stdout.flush()


def _process_image_files(name, filenames, texts, labels, num_shards):
    """Process and save a list of images as TFRecord files of Example protos.

    :param name: string, unique identifier specifying the data set.
    :param filenames: list of strings; each string is a path to an image file.
    :param texts: list of strings; each string is human readable, e.g. 'dog'.
    :param labels: list of integers; each integer identifies the ground truth.
    :param num_shards: integer, number of shards for this data set.
    """
    # filenames, texts and labels must line up one-to-one.
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with [ranges[i][0], ranges[i][1]].
    # For example, spacing = [0, 1835, 3670]: images 0..1835 go to one thread
    # and 1835..3670 to the other.
    spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
    # Split spacing into pairs, here [0, 1835] and [1835, 3670].
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished:
    # TensorFlow's thread coordinator.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image coding.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames, texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' % (datetime.now(), len(filenames)))
    sys.stdout.flush()


def _process_dataset(name, directory, num_shards, labels_file):
    """Process a complete data set and save it as TFRecord files.

    Args:
        name: string, unique identifier specifying the data set.
        directory: string, root path to the data set.
        num_shards: integer, number of shards for this data set.
        labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file)
    _process_image_files(name, filenames, texts, labels, num_shards)


def main(unused_argv):
    assert not FLAGS.train_shards % FLAGS.num_threads, (
        'The number of training shards must be divisible by the number of threads.')
    assert not FLAGS.validation_shards % FLAGS.num_threads, (
        'The number of validation shards must be divisible by the number of threads.')
    print('Saving results to %s' % FLAGS.output_directory)
    # Run it!
    # Training set
    _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file)
    # Validation set (not built here)
    # _process_dataset('validation', FLAGS.validation_directory, FLAGS.validation_shards, FLAGS.labels_file)


if __name__ == '__main__':
    tf.app.run()

Output:

Saving results to ./data/
Data directory: ./flower_photos/.
Instructions for updating:
Use tf.gfile.GFile.
Found 3670 JPEG files across 5 labels inside ./flower_photos/.
Launching 2 threads for spacings: [[0, 1835], [1835, 3670]]
2019-08-28 12:49:17.142402 [thread 0]: Processed 1000 of 1835 images in thread batch.
2019-08-28 12:49:17.362402 [thread 1]: Processed 1000 of 1835 images in thread batch.
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to ./data/train-00000-of-00002.tfrecord
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to ./data/train-00001-of-00002.tfrecord
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:26.274402: Finished writing all 3670 images in data set.

Generated TFRecord files:
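One way to sanity-check a generated shard (not part of the original script; a minimal sketch using the TF 1.x record iterator) is to iterate over it and parse a record back into an Example:

import tensorflow as tf

path = './data/train-00000-of-00002.tfrecord'
count = 0
for record in tf.python_io.tf_record_iterator(path):
    if count == 0:
        # Parse the first record and print a few of its features.
        example = tf.train.Example.FromString(record)
        feat = example.features.feature
        print('label:', feat['image/class/label'].int64_list.value[0])
        print('text :', feat['image/class/text'].bytes_list.value[0])
        print('size :', feat['image/height'].int64_list.value[0], 'x',
              feat['image/width'].int64_list.value[0])
    count += 1
print('records in shard:', count)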

Reference:

https://blog.csdn.net/moyu123456789/article/details/83956366
