当前位置:   article > 正文

TensorFlow1(二)文件读取_tensorflow根据文件夹读取数据

tensorflow根据文件夹读取数据

我们首先说一下文件读取的流程(分别讨论文本文件、图片文件以及二进制文件):

 1、构造文件名队列

file_queue= tf.train.string_input_producer(string_tensor,shuffle = True)

2、读取与解码

文本:

        读取:tf.TextLineReader()

        解码:tf.decode_csv()

图片:

        读取:tf.WholeFileReader()

        解码:tf.image.decode_jpeg(contents)

                   tf.image.decode_png(contents)

二进制文件:

         读取:tf.FixedLengthRecordReader(record_bytes)
         解码:tf.decode_raw()

3、批处理队列

 tf.train.batch(tensors, batch_size, num_threads = 1, capacity = 32, name=None)

手动开启线程,先创建线程协调员:coord = tf.train.Coordinator()

开启线程(需在会话中调用,返回线程列表):threads = tf.train.start_queue_runners(sess=None, coord=None)

回收线程:coord.request_stop()   coord.join(threads)

实例一:狗的图片读取

我这里使用的是100张狗的图片,图片类型是jpg,因为每个图片的规格都是不同的,因此加入了resize(统一缩放为200×200)并固定静态形状的步骤。

  import os

  # Set before importing TensorFlow so its C++ INFO/WARNING logs are suppressed.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  import tensorflow as tf

  # Image data is a 3-D numeric array: [height, width, channels].
  #   grayscale image: [height, width, 1], each pixel in [0, 255]
  #   colour image:    [height, width, 3], each pixel in [0, 255]


  def picture_read(file_list):
      """Read JPEG images through a TF1 input queue and print one batch.

      Builds the pipeline filename queue -> WholeFileReader -> JPEG decode ->
      resize to 200x200 -> batch of 100, runs it once in a session and prints
      the fetched values.

      Args:
          file_list: list of paths to JPEG files.

      Returns:
          None. Results are only printed.
      """
      # 1. Build the filename queue.
      file_queue = tf.train.string_input_producer(file_list)

      # 2. Read and decode.
      reader = tf.WholeFileReader()
      # key is the file name, value is the raw encoded bytes of one image.
      key, value = reader.read(file_queue)
      print(key, value)

      # Force 3 output channels so grayscale JPEGs are expanded to RGB and
      # agree with the static [200, 200, 3] shape fixed below.
      image = tf.image.decode_jpeg(value, channels=3)
      print(image)

      # The source images differ in size, so bring them all to 200x200.
      image_resized = tf.image.resize_images(image, size=[200, 200])
      print(image_resized)

      # Fix the static shape; tf.train.batch needs fully defined shapes.
      image_resized.set_shape(shape=[200, 200, 3])

      # 3. Batch.
      image_batch = tf.train.batch(
          [image_resized], batch_size=100, num_threads=1, capacity=100
      )
      print("image_batch:", image_batch)

      with tf.Session() as sess:
          # The coordinator lets us stop and join the queue-runner threads.
          coord = tf.train.Coordinator()
          threads = tf.train.start_queue_runners(sess=sess, coord=coord)
          try:
              # Fetched values get their own names so the graph tensors above
              # are not shadowed by NumPy results.
              (key_new, value_new, image_new,
               image_resized_new, image_batch_new) = sess.run(
                  [key, value, image, image_resized, image_batch]
              )
              print("key_new:", key_new)
              # value_new holds the raw JPEG bytes; it is large, so not printed.
              print("image_new", image_new)
              print("image_resized", image_resized_new)
              print("image_batch", image_batch_new)
          finally:
              # Always shut the threads down, even if sess.run raised.
              coord.request_stop()
              coord.join(threads)


  if __name__ == "__main__":
      image_dir = "./dog"
      # Keep only JPEG files: anything else in the directory would make
      # decode_jpeg fail inside the queue-runner thread.
      file_list = [
          os.path.join(image_dir, name)
          for name in os.listdir(image_dir)
          if name.lower().endswith((".jpg", ".jpeg"))
      ]
      picture_read(file_list)

示例二:读取二进制文件

本示例采用的是CIFAR-10 binary version (suitable for C programs)

  import os

  # Set before importing TensorFlow so its C++ INFO/WARNING logs are suppressed.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  import tensorflow as tf


  class Cifer(object):
      """Queue-based reader for the CIFAR-10 binary version.

      Each record is read as label_bytes + image_bytes bytes: one label byte
      followed by the image bytes, which are reshaped as
      [channels, height, width].
      """

      def __init__(self):
          """Set the image geometry and the derived per-record byte counts."""
          self.height = 32
          self.width = 32
          self.channels = 3

          # Byte counts of one record.
          self.image_bytes = self.height * self.width * self.channels
          self.label_bytes = 1
          self.all_bytes = self.label_bytes + self.image_bytes

      def read_and_decode(self, file_list):
          """Read fixed-length records and return one batch of 100 samples.

          Args:
              file_list: list of paths to CIFAR-10 ``.bin`` files.

          Returns:
              Tuple ``(image_value, label_value)`` of fetched arrays: images
              with shape [100, height, width, channels] as float32, and
              labels with shape [100, 1] as uint8.
          """
          # Build the filename queue.
          file_queue = tf.train.string_input_producer(file_list)

          # Read: each read yields exactly one record of all_bytes bytes.
          reader = tf.FixedLengthRecordReader(self.all_bytes)
          # key identifies the record, value is the raw bytes of one sample.
          key, value = reader.read(file_queue)
          print("key:", key)
          print("value:", value)

          # Decode the raw bytes into a flat uint8 vector.
          decoded = tf.decode_raw(value, tf.uint8)
          print("decoded:", decoded)

          # Split the vector into the label and the image features.
          label = tf.slice(decoded, [0], [self.label_bytes])
          image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
          print("label:", label)
          print("image:", image)

          # The bytes are laid out channel-first, so reshape to
          # [channels, height, width] before reordering.
          image_reshaped = tf.reshape(
              image, shape=[self.channels, self.height, self.width]
          )
          print("image_reshaped:", image_reshaped)

          # Transpose to the usual [height, width, channels] order.
          image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
          print("image_transposed", image_transposed)

          # Cast the pixels to float32.
          image_cast = tf.cast(image_transposed, tf.float32)

          # Batch.
          label_batch, image_batch = tf.train.batch(
              [label, image_cast], batch_size=100, num_threads=1, capacity=100
          )
          print("label_batch", label_batch)
          print("image_batch", image_batch)

          with tf.Session() as sess:
              # The coordinator lets us stop and join the queue-runner threads.
              coord = tf.train.Coordinator()
              threads = tf.train.start_queue_runners(sess=sess, coord=coord)
              try:
                  # This run reads one record directly; the batch fetched
                  # below comes from separate reads, so the samples differ.
                  (key_new, value_new, decoded_new, label_new, image_new,
                   image_reshaped_new, image_transposed_new) = sess.run(
                      [key, value, decoded, label, image,
                       image_reshaped, image_transposed]
                  )
                  label_value, image_value = sess.run([label_batch, image_batch])

                  print("key_new:", key_new)
                  # value_new holds the raw record bytes; not printed.
                  print("decoded_new", decoded_new)
                  print("label_new", label_new)
                  print("image_new", image_new)
                  print("image_reshaped_new", image_reshaped_new)
                  print("image_transposed_new", image_transposed_new)
                  # Print the fetched values, not the symbolic tensors.
                  print("label_batch", label_value)
                  print("image_batch", image_value)
              finally:
                  # Always shut the threads down, even if sess.run raised.
                  coord.request_stop()
                  coord.join(threads)

          return image_value, label_value


  if __name__ == "__main__":
      data_dir = "./cifar-10-binary/cifar-10-batches-bin"
      # Build the path list from the .bin files only.
      file_list = [
          os.path.join(data_dir, name)
          for name in os.listdir(data_dir)
          if name.endswith(".bin")
      ]

      cifar = Cifer()
      image_batch, label_batch = cifar.read_and_decode(file_list)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/249805
推荐阅读
相关标签
  

闽ICP备14008679号