当前位置:   article > 正文

tensorflow分类任务——TFRecord读取自己制作的数据集_分类任务 tfrecorder

分类任务 tfrecorder

一、TensorFlow的数据读取机制

注意:这个地址是TensorFlow的数据读取机制,如果了解请跳过。

原博客地址:https://zhuanlan.zhihu.com/p/27238630

建议阅读博客https://blog.csdn.net/pursuit_zhangyu/article/details/80607529

代码地址:https://github.com/hzy46/Deep-Learning-21-Examples/tree/master/chapter_2

1.1关键函数解读

  1. 对于文件名队列,我们使用tf.train.string_input_producer函数。这个函数需要传入一个文件名list,系统会自动将它转为一个文件名队列。
  2.  reader = tf.TFRecordReader()创建读取
  3.  imageBatch, labelBatch = tf.train.shuffle_batch([image, labels], batch_size=batchSize,
                                                        capacity=capacity, min_after_dequeue=min_after_dequeue)打包读取,意思为小批次读取数据
  4. threads = tf.train.start_queue_runners(sess=sess, coord=coord)创建会话和多线程,启动读取

二、TFRecord读取数据集 

我的上一篇文章,我采用自己的图片制作了数据集,现在我写一下读取自己制作的数据集。

数据集地址:

链接:https://pan.baidu.com/s/1aIHzKsxUb67sJZAFrGH1ZQ 
提取码:lvjp 


工程地址:

链接:https://pan.baidu.com/s/1XGAA6UQ0JByhvDYQ__my4g 
提取码:dxpn 
 

  1. import numpy as np
  2. import tensorflow as tf
  3. batchSize = 15
  4. num_epochs = 20
  5. def tfRecordRead(fileNameQue, heigh, width, channels, n_class):
  6. reader = tf.TFRecordReader()
  7. # 创建一个队列来维护输入文件列表
  8. # 从文件中读出一个Example
  9. _, serialized_example = reader.read(fileNameQue)
  10. # 用FixedLenFeature将读入的Example解析成tensor
  11. features = tf.parse_single_example(
  12. serialized_example,
  13. features={
  14. 'image': tf.FixedLenFeature([], tf.string),
  15. 'label': tf.FixedLenFeature([], tf.int64)
  16. })
  17. # 将字符串解析成图像对应的像素数组
  18. image = tf.decode_raw(features['image'], tf.float32)
  19. # image = tf.decode_raw(features["image"], tf.uint8)
  20. image = tf.reshape(image, [heigh, width, channels])
  21. # image = tf.cast(image, tf.float32) * (1 / 255.0)
  22. labels = tf.cast(features['label'], tf.int64)
  23. labels = tf.one_hot(labels, n_class)
  24. return image, labels
  25. def tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize):
  26. fileNameQue = tf.train.string_input_producer([filename], shuffle=False, num_epochs=num_epochs)
  27. image, labels = tfRecordRead(fileNameQue, heigh, width, channels, n_class) # fetch图像和label
  28. min_after_dequeue = 1000
  29. capacity = min_after_dequeue + 3 * batchSize
  30. # 预取图像和label并随机打乱,组成batch,此时tensor rank发生了变化,多了一个batch大小的维度
  31. imageBatch, labelBatch = tf.train.shuffle_batch([image, labels], batch_size=batchSize,
  32. capacity=capacity, min_after_dequeue=min_after_dequeue)
  33. return imageBatch, labelBatch
  34. filename = r'./record\Imageoutput.tfrecords'
  35. # filename = 'Imageoutput.tfrecords'
  36. dataset = np.load('testData.npz')
  37. x_test = dataset['test_X'][1:20]
  38. y_test = dataset['test_Y'][1:20]
  39. heigh, width, channels, n_class = dataset['height'], dataset['width'], dataset['channels'], dataset['n_class']
  40. print(heigh, width, channels, n_class)
  41. imageBatch, labelBatch = tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize)
  42. # init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  43. with tf.Session() as sess:
  44. sess.run(tf.global_variables_initializer())
  45. sess.run(tf.local_variables_initializer())
  46. coord = tf.train.Coordinator()
  47. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  48. for i in range(num_epochs):
  49. example, label = sess.run([imageBatch, labelBatch])
  50. print(label)
  51. coord.request_stop()
  52. coord.join(threads)

上述代码把label输出了,运行结果如下: 

 

注意事项:

TFRecord读取数据集的过程中比前面说的TensorFlow数据读取机制多了一步:从TFRecord文件中解析出数据

  1. features = tf.parse_single_example(
  2. serialized_example,
  3. features={
  4. 'image': tf.FixedLenFeature([], tf.string),
  5. 'label': tf.FixedLenFeature([], tf.int64)
  6. })
  7. # 将字符串解析成图像对应的像素数组
  8. image = tf.decode_raw(features['image'], tf.float32)
  9. # image = tf.decode_raw(features["image"], tf.uint8)
  10. image = tf.reshape(image, [heigh, width, channels])
  11. # image = tf.cast(image, tf.float32) * (1 / 255.0)
  12. labels = tf.cast(features['label'], tf.int64)
  13. labels = tf.one_hot(labels, n_class)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/123910?site
推荐阅读
相关标签
  

闽ICP备14008679号