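Note: the first link below explains TensorFlow's data-reading mechanism; skip it if you are already familiar with it.
Original post: https://zhuanlan.zhihu.com/p/27238630
Recommended reading: https://blog.csdn.net/pursuit_zhangyu/article/details/80607529
Code: https://github.com/hzy46/Deep-Learning-21-Examples/tree/master/chapter_2
In my previous post I built a dataset from my own images. This post shows how to read that dataset back.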
Dataset:
Link: https://pan.baidu.com/s/1aIHzKsxUb67sJZAFrGH1ZQ
Extraction code: lvjp
Project:
Link: https://pan.baidu.com/s/1XGAA6UQ0JByhvDYQ__my4g
Extraction code: dxpn
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (queue runners were removed in TF 2.x)

batchSize = 15
num_epochs = 20


def tfRecordRead(fileNameQue, heigh, width, channels, n_class):
    reader = tf.TFRecordReader()
    # Read one serialized Example from the file-name queue
    _, serialized_example = reader.read(fileNameQue)
    # Parse the serialized Example into tensors with FixedLenFeature
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    # Decode the byte string into the image's pixel array.
    # The dtype must match the one used when the TFRecord was written (float32 here).
    image = tf.decode_raw(features['image'], tf.float32)
    # image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(image, [heigh, width, channels])
    # image = tf.cast(image, tf.float32) * (1 / 255.0)
    labels = tf.cast(features['label'], tf.int64)
    labels = tf.one_hot(labels, n_class)
    return image, labels


def tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize):
    # Create a queue that holds the list of input files
    fileNameQue = tf.train.string_input_producer([filename], shuffle=False, num_epochs=num_epochs)
    image, labels = tfRecordRead(fileNameQue, heigh, width, channels, n_class)  # fetch image and label
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batchSize
    # Prefetch images and labels, shuffle them and group them into batches.
    # The tensor rank grows by one: a leading batch dimension is added.
    imageBatch, labelBatch = tf.train.shuffle_batch([image, labels], batch_size=batchSize,
                                                    capacity=capacity, min_after_dequeue=min_after_dequeue)
    return imageBatch, labelBatch


# Forward slashes work on both Windows and Linux
filename = './record/Imageoutput.tfrecords'
# filename = 'Imageoutput.tfrecords'

dataset = np.load('testData.npz')
x_test = dataset['test_X'][1:20]
y_test = dataset['test_Y'][1:20]

# The .npz stores these as 0-d arrays; convert them to Python ints for the reshape and one_hot shapes
heigh, width, channels, n_class = (int(dataset['height']), int(dataset['width']),
                                   int(dataset['channels']), int(dataset['n_class']))
print(heigh, width, channels, n_class)

imageBatch, labelBatch = tfRecordBatchRead(filename, heigh, width, channels, n_class, batchSize)
# init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    # num_epochs in string_input_producer creates a local variable, so both initializers are needed
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # Fetch num_epochs batches (this is not num_epochs full passes over the data)
        for i in range(num_epochs):
            example, label = sess.run([imageBatch, labelBatch])
            print(label)
    except tf.errors.OutOfRangeError:
        # Raised once all num_epochs passes over the file have been consumed
        print('Input queue exhausted')
    finally:
        coord.request_stop()
        coord.join(threads)
```
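The `tf.train.shuffle_batch` call above uses `min_after_dequeue = 1000` and `capacity = min_after_dequeue + 3 * batchSize`. To make those two parameters concrete, here is a simplified pure-Python model of a shuffle queue. It is a sketch under stated assumptions, not TensorFlow's implementation: the helper `shuffle_batches` is hypothetical, it ignores threading, and it does not enforce the `capacity` upper bound. It only shows the core idea that each element is drawn at random from a buffer that keeps at least `min_after_dequeue` elements behind while input remains.

```python
import random


def shuffle_batches(stream, batch_size, min_after_dequeue, seed=0):
    """Hypothetical pure-Python model of a shuffle queue (not TensorFlow code).

    While input remains, the buffer is refilled so that at least
    `min_after_dequeue` elements stay behind after each dequeue; each
    element is drawn uniformly at random from the buffer.
    """
    rng = random.Random(seed)
    it = iter(stream)
    buffer, batch = [], []
    exhausted = False
    while True:
        # Keep min_after_dequeue + 1 items buffered until the input runs out
        while not exhausted and len(buffer) <= min_after_dequeue:
            try:
                buffer.append(next(it))
            except StopIteration:
                exhausted = True  # like a closed queue: the minimum no longer applies
        if not buffer:
            break
        # Remove a uniformly random element (swap with the last one, then pop)
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]
        batch.append(buffer.pop())
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A final partial batch is dropped, mirroring shuffle_batch's default
    # allow_smaller_final_batch=False


min_after_dequeue = 1000
batch_size = 15
capacity = min_after_dequeue + 3 * batch_size
print(capacity)  # 1045

batches = list(shuffle_batches(range(100), batch_size=15, min_after_dequeue=20))
print(len(batches))  # 6 (90 of the 100 items; the last 10 are dropped)
```

A larger `min_after_dequeue` mixes the data more thoroughly but means more examples must be queued before the first batch comes out. With `min_after_dequeue=1000` and a small file, those examples can come from several of the `num_epochs` passes, so examples from different epochs may end up in the same batch.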
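The code above prints the labels of each batch. The output is as follows:
Notes:
Compared with the TensorFlow data-reading mechanism mentioned at the start, reading a dataset from a TFRecord file adds one extra step: parsing the data out of the TFRecord file.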
```python
features = tf.parse_single_example(
    serialized_example,
    features={
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    })
# Decode the byte string into the image's pixel array
image = tf.decode_raw(features['image'], tf.float32)
# image = tf.decode_raw(features["image"], tf.uint8)
image = tf.reshape(image, [heigh, width, channels])
# image = tf.cast(image, tf.float32) * (1 / 255.0)
labels = tf.cast(features['label'], tf.int64)
labels = tf.one_hot(labels, n_class)
```
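The decoding step in this excerpt only works if `tf.decode_raw` uses the same dtype the image had when it was serialized. The NumPy sketch below mimics the three operations (`decode_raw`, `reshape`, `one_hot`) to show why. It assumes the writer stored `image.tobytes()` of a float32 array; the sizes are made-up values, not the ones in `testData.npz`.

```python
import numpy as np

# Hypothetical sizes; in the post they are read from testData.npz
height, width, channels, n_class = 4, 5, 3, 10

# Assumption: the writer stored the raw bytes of a float32 image array
image = np.random.default_rng(0).random((height, width, channels), dtype=np.float32)
raw = image.tobytes()

# Counterpart of tf.decode_raw(raw, tf.float32) followed by tf.reshape
decoded = np.frombuffer(raw, dtype=np.float32).reshape(height, width, channels)
assert np.array_equal(decoded, image)

# Decoding the same bytes as uint8 yields 4 values per float32 element,
# so reshaping to (height, width, channels) fails
wrong = np.frombuffer(raw, dtype=np.uint8)
print(decoded.size, wrong.size)  # 60 240
try:
    wrong.reshape(height, width, channels)
except ValueError as err:
    print('reshape failed:', err)

# Counterpart of tf.one_hot(label, n_class): a float32 row with a single 1.0
label = 3
one_hot = np.eye(n_class, dtype=np.float32)[label]
print(one_hot)
```

This is why the commented-out `tf.uint8` line and the `* (1 / 255.0)` scaling belong together: they are the variant to use when the images were written as 8-bit pixels rather than as float32 values.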