当前位置:   article > 正文

TensorFlow详解猫狗识别(一)--读取自己的数据集_def read_image_filenames(data_dir): cat_dir = data

def read_image_filenames(data_dir): cat_dir = data_dir + 'cat/' dog_dir = da

数据集下载

链接: https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密码: hpn4

数据集分别有12500张cat,12500张dog

读取数据集

数据集的读取,查阅了那么多文档,大致了解到,数据集的读取方法大概会分为两种

1、先生成图片list,和标签list,把图片名称和标签对应起来,再读取制作迭代器(个人认为此方法一般用在,图片名称上可以明确的知道label的)

2、直接生成TFRecord文件,用tf.TFRecordReader()来读取,个人认为,当图片量很大的时候(如:ImageNet)很使用,保存了TFRecord文件后,一劳永逸,省去了生成list的过程

下面贴出代码,简单介绍两种读取数据集的方式。

方法一:

  1. import os
  2. import tensorflow as tf
  3. from PIL import Image
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import cv2
  7. # os模块包含操作系统相关的功能,
  8. # 可以处理文件和目录这些我们日常手动需要做的操作。因为我们需要获取test目录下的文件,所以要导入os模块。
  9. #
  10. # 数据构成,在训练数据中,There are 12500 cat,There are 12500 dogs,共25000张
  11. # 获取文件路径和标签
  12. def get_files(file_dir):
  13. # file_dir: 文件夹路径
  14. # return: 乱序后的图片和标签
  15. cats = []
  16. label_cats = []
  17. dogs = []
  18. label_dogs = []
  19. # 载入数据路径并写入标签值
  20. for file in os.listdir(file_dir):
  21. name = file.split(sep='.')
  22. # name的形式为['dog', '9981', 'jpg']
  23. # os.listdir将名字转换为列表表达
  24. if name[0] == 'cat':
  25. cats.append(file_dir + file)
  26. # 注意文件路径和名字之间要加分隔符,不然后面查找图片会提示找不到图片
  27. # 或者在后面传路径的时候末尾加两// 'D:/Python/neural network/Cats_vs_Dogs/data/train//'
  28. label_cats.append(0)
  29. else:
  30. dogs.append(file_dir + file)
  31. label_dogs.append(1)
  32. # 猫为0,狗为1
  33. print("There are %d cats\nThere are %d dogs" % (len(cats), len(dogs)))
  34. # 打乱文件顺序
  35. image_list = np.hstack((cats, dogs))
  36. label_list = np.hstack((label_cats, label_dogs))
  37. # np.hstack()方法将猫和狗图片和标签整合到一起,标签也整合到一起
  38. temp = np.array([image_list, label_list])
  39. # 这里的数组出来的是2行10列,第一行是image_list的数据,第二行是label_list的数据
  40. temp = temp.transpose() # 转置
  41. # 将其转换为10行2列,第一列是image_list的数据,第二列是label_list的数据
  42. np.random.shuffle(temp)
  43. # 对应的打乱顺序
  44. image_list = list(temp[:, 0]) # 取所有行的第0列数据
  45. label_list = list(temp[:, 1]) # 取所有行的第1列数据,并转换为int
  46. label_list = [int(i) for i in label_list]
  47. return image_list, label_list
  48. # 生成相同大小的批次
  49. def get_batch(image, label, image_W, image_H, batch_size, capacity):
  50. # image, label: 要生成batch的图像和标签list
  51. # image_W, image_H: 图片的宽高
  52. # batch_size: 每个batch有多少张图片
  53. # capacity: 队列容量
  54. # return: 图像和标签的batch
  55. # 将原来的python.list类型转换成tf能够识别的格式
  56. image = tf.cast(image, tf.string)#强制类型转换
  57. label = tf.cast(label, tf.int32)
  58. # 生成队列。我们使用slice_input_producer()来建立一个队列,将image和label放入一个list中当做参数传给该函数
  59. input_queue = tf.train.slice_input_producer([image, label])
  60. image_contents = tf.read_file(input_queue[0])
  61. # 按队列读数据和标签
  62. label = input_queue[1]
  63. image = tf.image.decode_jpeg(image_contents, channels=3)
  64. # 要按照图片格式进行解码。本例程中训练数据是jpg格式的,所以使用decode_jpeg()解码器,
  65. # 如果是其他格式,就要用其他geshi具体可以从官方API中查询。
  66. # 注意decode出来的数据类型是uint8,之后模型卷积层里面conv2d()要求输入数据为float32类型
  67. # 统一图片大小
  68. # 通过裁剪统一,包括裁剪和扩充
  69. # image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
  70. # 我的方法,通过缩小图片,采用NEAREST_NEIGHBOR插值方法
  71. image = tf.image.resize_images(image, [image_H, image_W], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
  72. align_corners=False)
  73. image = tf.cast(image, tf.float32)
  74. # 因为没有标准化,所以需要转换类型
  75. # image = tf.image.per_image_standardization(image) # 标准化数据
  76. image_batch, label_batch = tf.train.batch([image, label],
  77. batch_size=batch_size,
  78. num_threads=64, # 线程
  79. capacity=capacity)
  80. # image_batch是一个4D的tensor,[batch, width, height, channels],
  81. # label_batch是一个1D的tensor,[batch]。
  82. # 这行多余?
  83. label_batch = tf.reshape(label_batch, [batch_size])
  84. return image_batch, label_batch
  85. '''
  86. 下面代码为查看图片效果,主要用于观察图片是否打乱,你会可能会发现,图片显示出来的是一堆乱点,不用担心,这是因为你对图片的每一个像素进行了强制类型转化为了tf.float32,使像素值介于-1~1之间,若想看原图,可使用tf.uint8,像素介于0~255
  87. '''
  88. # print("yes")
  89. # image_list,label_list = get_files("E:\\Pycharm\\tf-01\\Bigwork\\train\\")
  90. # image_batch,label_batch = train_batch,train_label_batch = get_batch(image_list,label_list,208,208,4,256)
  91. # print("ok")
  92. #
  93. # for i in range(4):
  94. # with tf.Session() as sess:
  95. # i = 0
  96. # coord = tf.train.Coordinator()
  97. # threads = tf.train.start_queue_runners(coord=coord)
  98. # try:
  99. # while not coord.should_stop() and i < 1:
  100. # # just plot one batch size
  101. # image, label = sess.run([image_batch, label_batch])
  102. # for j in np.arange(4):
  103. # print('label: %d' % label[j])
  104. # plt.imshow(image[j, :, :, :])
  105. # plt.show()
  106. # i += 1
  107. # except tf.errors.OutOfRangeError:
  108. # print('done!')
  109. # finally:
  110. # coord.request_stop()
  111. # coord.join(threads)
  112. # for i in range(4):
  113. # sess = tf.Session()
  114. # image,label = sess.run([image_batch,label_batch])
  115. # for j in range(4):
  116. # print('label:%d' % label[j])
  117. # plt.imshow(image[j, :, :, :])
  118. # plt.show()
  119. # sess.close()

方法二

  1. import os
  2. import tensorflow as tf
  3. from PIL import Image
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import cv2
  7. cwd = "E:\\Pycharm\\tf-01\\Bigwork\\test\\"
  8. classes = {'cat', 'dog'} # 预先自己定义的类别
  9. writer = tf.python_io.TFRecordWriter('test.tfrecords') # 输出成tfrecord文件
  10. def _int64_feature(value):
  11. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  12. def _bytes_feature(value):
  13. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  14. for index, name in enumerate(classes):
  15. class_path = cwd + name + '\\'
  16. print(class_path)
  17. for img_name in os.listdir(class_path):
  18. img_path = class_path + img_name # 每个图片的地址
  19. img = Image.open(img_path)
  20. img = img.resize((208, 208))
  21. img_raw = img.tobytes() # 将图片转化为二进制格式
  22. example = tf.train.Example(features=tf.train.Features(feature={
  23. "label": _int64_feature(index),
  24. "img_raw": _bytes_feature(img_raw),
  25. }))
  26. writer.write(example.SerializeToString()) # 序列化为字符串
  27. writer.close()
  28. print("writed OK")
  29. #生成tfrecord文件后,下次可以不用再执行这段代码!!!
  30. def read_and_decode(filename,batch_size): # read train.tfrecords
  31. filename_queue = tf.train.string_input_producer([filename])
  32. reader = tf.TFRecordReader()
  33. _, serialized_example = reader.read(filename_queue)
  34. features = tf.parse_single_example(serialized_example,
  35. features={
  36. 'label': tf.FixedLenFeature([], tf.int64),
  37. 'img_raw': tf.FixedLenFeature([], tf.string),
  38. })
  39. img = tf.decode_raw(features['img_raw'], tf.float32)
  40. img = tf.reshape(img, [128, 128, 3]) # reshape image to 208*208*3
  41. #据说下面这行多余
  42. #img = tf.cast(img,tf.float32)*(1./255)-0.5
  43. label = tf.cast(features['label'], tf.int64)
  44. img_batch, label_batch = tf.train.shuffle_batch([img, label],
  45. batch_size=batch_size,
  46. num_threads = 8,
  47. capacity = 100,
  48. min_after_dequeue = 60,)
  49. return img_batch, tf.reshape(label_batch, [batch_size])
  50. filename = './/train.tfrecords'
  51. image_batch、label_batch = read_and_decode(filename,batch_size)

下一篇:关于神经网络的定义!(给出亲测的两个神经网络模型LeNet、AlexNet)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/123903
推荐阅读
相关标签
  

闽ICP备14008679号