当前位置:   article > 正文

TensorFlow下如何将图片制作成数据集_将分好类的图片制作tensorflow可以使用的数据集

将分好类的图片制作tensorflow可以使用的数据集

引言:

在做TensorFlow案例时发现好多的图片数据集都是处理好的,直接在库中调用的。比如Mnist,CIFAR-10等等。但是在跑自己项目的时候如何去读取自己的数据集呢?其实,一方面TensorFlow官方已经给出方法,那就是将图片制作成tfrecord格式的数据,供TensorFlow读取。另一方面Python以及Python的图像处理第三方库都有读取制作的方法,种类繁杂。

下面我将介绍两种方法:1.用python制作数据集2.基于TensorFlow制作tfrecord格式的数据集

一 用python制作数据集

代码比较简单这里做一下简单的说明:

1.一定要把.py文件放到图片所在的文件夹内,因为程序获取的路径是.py文件下的路径,但是你的源图片路径也得有图片否则回报错(目前是什么原因造成的还没发现,以后补充)。

2.程序已经写成函数了,所以只需要把图片路径以及将图片放到.py文件下就行了。参数有路径path和需要制作的标签Lables。

直接上代码:

  1. import os
  2. import matplotlib.pyplot as plt
  3. import matplotlib.image as mpimg
  4. import numpy as np
  5. def make_data(path,labels):
  6. def getAllimages(folder):
  7. assert os.path.exists(folder)
  8. assert os.path.isdir(folder)
  9. imageList = os.listdir(folder)
  10. imageList = [os.path.abspath(item) for item in imageList if os.path.isfile(os.path.join(folder, item))]
  11. return imageList
  12. ImageList=getAllimages(path)
  13. TrainList=[]
  14. Lable=[]
  15. Img_data=[]
  16. for i in range(len(ImageList)):
  17. string=str(ImageList[i])
  18. List=mpimg.imread(string)
  19. TrainList.append(List)
  20. Lable1=labels
  21. Lable.append(Lable1)
  22. Img = np.hstack((TrainList, Lable))
  23. Img_data=Img[:len(TrainList)]
  24. Img_lable=Img[len(TrainList):]
  25. return Img_data,Img_lable
  26. path=(r'/home/wcy/图片')
  27. img,lable=make_data(path,0)
  28. print(lable)

注意:/home/wcy/图片目录下有需要制作的图片以及.py文件夹下也应该有图片。

二 基于TensorFlow制作tfrecord格式的数据集

整个程序分为两部分一个是make_image_TFRecord另一部分是read_Tfrecord。

1.make_image_TFRecord.py

  1. import os
  2. import tensorflow as tf
  3. from PIL import Image
  4. import numpy as np
  5. import pandas as pd
  6. # 原始图片的存储位置
  7. orig_picture = os.getcwd()+'\\image\\test'
  8. # 生成图片的存储位置
  9. gen_picture = os.getcwd()+'\\image'
  10. # 需要的识别类型
  11. classes = {'0', '1'}
  12. # 样本总数
  13. num_samples = 40
  14. # 制作TFRecords数据
  15. def create_record():
  16. writer = tf.python_io.TFRecordWriter("test.tfrecords")
  17. for index, name in enumerate(classes):
  18. class_path = orig_picture + "/" + name + "/"
  19. for img_name in os.listdir(class_path):
  20. img_path = class_path + img_name
  21. img = Image.open(img_path)
  22. img = img.resize((32, 32)) # 设置需要转换的图片大小
  23. ###图片灰度化######################################################################
  24. # img=img.convert("L")
  25. ##############################################################################################
  26. img_raw = img.tobytes() # 将图片转化为原生bytes
  27. example = tf.train.Example(
  28. features=tf.train.Features(feature={
  29. "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
  30. 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
  31. }))
  32. writer.write(example.SerializeToString())
  33. writer.close()
  34. # =======================================================================================
  35. def read_and_decode(filename,is_batch):
  36. # 创建文件队列,不限读取的数量
  37. filename_queue = tf.train.string_input_producer([filename])
  38. # create a reader from file queue
  39. reader = tf.TFRecordReader()
  40. # reader从文件队列中读入一个序列化的样本
  41. _, serialized_example = reader.read(filename_queue)
  42. # get feature from serialized example
  43. # 解析符号化的样本
  44. features = tf.parse_single_example(
  45. serialized_example,
  46. features={
  47. 'label': tf.FixedLenFeature([], tf.int64),
  48. 'img_raw': tf.FixedLenFeature([], tf.string)
  49. })
  50. label = features['label']
  51. img = features['img_raw']
  52. img = tf.decode_raw(img, tf.uint8)
  53. img = tf.reshape(img, [32, 32, 3])
  54. # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  55. label = tf.cast(label, tf.int32)
  56. if is_batch:
  57. batch_size = 3
  58. min_after_dequeue = 10
  59. capacity = min_after_dequeue + 3 * batch_size
  60. img, label = tf.train.shuffle_batch([img, label],
  61. batch_size=batch_size,
  62. num_threads=3,
  63. capacity=capacity,
  64. min_after_dequeue=min_after_dequeue)
  65. return img, label
  66. # =======================================================================================

2.read_Tfrecord.py

  1. import tensorflow as tf
  2. import os
  3. import pandas as pd
  4. from make_image_TFRecord import create_record
  5. from make_image_TFRecord import read_and_decode
  6. from PIL import Image
  7. num_samples = 40
  8. create_record()
  9. train_image, train_label = read_and_decode('test.tfrecords', is_batch=False)
  10. # 初始化变量
  11. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  12. # 创建一个session用于run输出结果
  13. with tf.Session() as sess: # 开始一个会话
  14. sess.run(init_op)
  15. coord = tf.train.Coordinator()
  16. threads = tf.train.start_queue_runners(coord=coord)
  17. data = pd.DataFrame()
  18. for i in range(num_samples):
  19. example, lab = sess.run([train_image, train_label]) # 在会话中取出image和label
  20. img = Image.fromarray(example, 'RGB') # 这里Image是之前提到的
  21. # img.save(gen_picture + '/' + str(i) + 'samples' + str(lab) + '.jpg') # 存下图片;注意cwd后边加上‘/’
  22. # img.save( '/' + str(i) + 'samples' + str(lab) + '.jpg') # 存下图片;注意cwd后边加上‘/’
  23. # print(example, lab)
  24. print(lab)
  25. coord.request_stop()
  26. coord.join(threads)
  27. sess.close() # 关闭会话
  28. # ========================================================================================

第一个程序运行完之后会生成一个.tfrecords格式的文件,然后再第二个程序中直接读取调用就行。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/123815
推荐阅读
相关标签
  

闽ICP备14008679号