当前位置:   article > 正文

Tensorflow构建自己的图片数据集TFrecords_tensorflow 自己的图片 数组

tensorflow 自己的图片 数组

 TensorFlow提供了标准的TFRecord 格式,而关于 tensorflow 读取数据, 官网也提供了3中方法 :
1 Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据
2 Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中
3 在声明tf.variable变量或numpy数组时保存数据。受限于内存大小,适用于数据较小的情况

在本文,主要介绍第二种方法,利用tf.record标准接口来读入文件

第一步,准备数据
先在网上下载一些不同类的图片集,例如猫、狗等,也可以是同一种类,不同类型的,例如哈士奇、吉娃娃等都属于狗类;此处笔者预先下载了哈士奇、吉娃娃两种狗的照片各20张,并分别将其放置在不同文件夹下。如下:

第二步,制作TFRecord文件
注意:tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签 如在本例中,只有0,1 两类

  1. # -----------------------------------------------------------------------------
  2. # encoding=utf-8
  3. import os
  4. import tensorflow as tf
  5. from PIL import Image
  6. cwd = 'ferro/train//'
  7. classes = {'Cutting', 'Fatigue','Normal'}
  8. # 制作TFRecords数据
  9. def create_record():
  10. writer = tf.python_io.TFRecordWriter("ferro_train.tfrecords")
  11. for index, name in enumerate(classes):
  12. class_path = cwd + "/" + name + "/"
  13. for img_name in os.listdir(class_path):
  14. img_path = class_path + img_name
  15. img = Image.open(img_path)
  16. img = img.resize((64, 64))
  17. img_raw = img.tobytes() # 将图片转化为原生bytes
  18. print(index, img_raw)
  19. example = tf.train.Example(
  20. features=tf.train.Features(feature={
  21. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
  22. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
  23. }))
  24. writer.write(example.SerializeToString())
  25. writer.close()

将上面的代码编辑完成后,点击运行,就会生成一个ferro_train.TFRecords文件,如下图所示:


TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

第三步,读取TFRecord文件

  1. # 读取二进制数据
  2. def read_and_decode(filename):
  3. # 创建文件队列,不限读取的数量
  4. filename_queue = tf.train.string_input_producer([filename])
  5. # create a reader from file queue
  6. reader = tf.TFRecordReader()
  7. # reader从文件队列中读入一个序列化的样本
  8. _, serialized_example = reader.read(filename_queue)
  9. # get feature from serialized example
  10. # 解析符号化的样本
  11. features = tf.parse_single_example(
  12. serialized_example,
  13. features={
  14. 'label': tf.FixedLenFeature([], tf.int64),
  15. 'img_raw': tf.FixedLenFeature([], tf.string)
  16. })
  17. label = features['label']
  18. img = features['img_raw']
  19. img = tf.decode_raw(img, tf.uint8)
  20. img = tf.reshape(img, [64, 64, 3])
  21. # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  22. label = tf.cast(label, tf.int32)
  23. return img, label


一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List。另外,需要我们注意的是:feature的属性“label”和“img_raw”名称要和制作时统一 ,返回的img数据和label数据一一对应。
第四步,TFRecord的显示操作
如果想要检查分类是否有误,或者在之后的网络训练过程中可以监视,输出图片,来观察分类等操作的结果,那么我们就可以session回话中,将tfrecord的图片从流中读取出来,再保存。因而自然少不了主程序的存在。

 

  1. if __name__ == '__main__':
  2. create_record()
  3. batch = read_and_decode('ferro_train.tfrecords')
  4. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  5. with tf.Session() as sess: # 开始一个会话
  6. sess.run(init_op)
  7. coord = tf.train.Coordinator()
  8. threads = tf.train.start_queue_runners(coord=coord)
  9. for i in range(200):
  10. example, lab = sess.run(batch) # 在会话中取出image和label
  11. img = Image.fromarray(example, 'RGB') # 这里Image是之前提到的
  12. img.save(cwd + '/' + str(i) + '_Label_' + str(lab) + '.jpg') # 存下图片;注意cwd后边加上‘/’
  13. print(example, lab)
  14. coord.request_stop()
  15. coord.join(threads)
  16. sess.close()


进过上面的一通操作之后,我们便可以得到和tensorflow官方的二进制数据集一样的数据集了,并且可以按照自己的设计来进行。
下面附上该程序的完整代码,仅供参考。

 

  1. # -----------------------------------------------------------------------------
  2. # encoding=utf-8
  3. import os
  4. import tensorflow as tf
  5. from PIL import Image
  6. cwd = 'ferro/train//'
  7. classes = {'Cutting', 'Fatigue','Normal'}
  8. # 制作TFRecords数据
  9. def create_record():
  10. writer = tf.python_io.TFRecordWriter("ferro_train.tfrecords")
  11. for index, name in enumerate(classes):
  12. class_path = cwd + "/" + name + "/"
  13. for img_name in os.listdir(class_path):
  14. img_path = class_path + img_name
  15. img = Image.open(img_path)
  16. img = img.resize((64, 64))
  17. img_raw = img.tobytes() # 将图片转化为原生bytes
  18. print(index, img_raw)
  19. example = tf.train.Example(
  20. features=tf.train.Features(feature={
  21. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
  22. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
  23. }))
  24. writer.write(example.SerializeToString())
  25. writer.close()
  26. # -------------------------------------------------------------------------
  27. # 读取二进制数据
  28. def read_and_decode(filename):
  29. # 创建文件队列,不限读取的数量
  30. filename_queue = tf.train.string_input_producer([filename])
  31. # create a reader from file queue
  32. reader = tf.TFRecordReader()
  33. # reader从文件队列中读入一个序列化的样本
  34. _, serialized_example = reader.read(filename_queue)
  35. # get feature from serialized example
  36. # 解析符号化的样本
  37. features = tf.parse_single_example(
  38. serialized_example,
  39. features={
  40. 'label': tf.FixedLenFeature([], tf.int64),
  41. 'img_raw': tf.FixedLenFeature([], tf.string)
  42. })
  43. label = features['label']
  44. img = features['img_raw']
  45. img = tf.decode_raw(img, tf.uint8)
  46. img = tf.reshape(img, [64, 64, 3])
  47. # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  48. label = tf.cast(label, tf.int32)
  49. return img, label
  50. # --------------------------------------------------------------------------
  51. # ---------主程序----------------------------------------------------------
  52. if __name__ == '__main__':
  53. create_record()
  54. #batch = read_and_decode('ferro_train.tfrecords')
  55. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  56. with tf.Session() as sess: # 开始一个会话
  57. sess.run(init_op)
  58. coord = tf.train.Coordinator()
  59. threads = tf.train.start_queue_runners(coord=coord)
  60. for i in range(200):
  61. example, lab = sess.run(batch) # 在会话中取出image和label
  62. img = Image.fromarray(example, 'RGB') # 这里Image是之前提到的
  63. img.save(cwd + '/' + str(i) + '_Label_' + str(lab) + '.jpg') # 存下图片;注意cwd后边加上‘/’
  64. print(example, lab)
  65. coord.request_stop()
  66. coord.join(threads)
  67. sess.close()
  68. # -----------------------------------------------------------------------------


运行上述的完整代码,便可以 将从TFRecord中取出的文件保存下来了。如下图:

每一幅图片的命名中,第二个数字则是 label,Cut都为1,Normal都为0;通过对照图片,可以发现图片分类正确。
 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/123794
推荐阅读
相关标签
  

闽ICP备14008679号