赞
踩
使用Tensorflow搭建卷积网络用于各种训练时,需要处理训练的图像和标签, 批量的输送给训练的网络。 Tensorflow训练数据的读取方法按我的理解可以分两类。
第一类,使用queue队列。第二类,使用tf.data.Dataset 对象。 第一类方法是传统的数据读取方法,使用简单,只需要两三行代码就可以实现,但缺点是数据需要完整的载入队列,对内存的消耗较大。因此,在使用一些比较小的数据集时,比如CIFAR-10,可以采用这种方法。第二类方法则是tensorflow 2.0 官方推荐的方法,更具灵活性。
通过将训练数据放入一个队列里,使用tensorflow的队列方法来为训练提供数据。
可以使用tf.train.slice_input_producer和tf.train.batch配合来获得训练数据。具体方法,参考我的第一篇 Tensorflow系列文章 [Tensorflow]第一课 创建一个数据队列
这里再载介绍一个更最简单的方法,使用tf.train的string_input_producer 方法和 tf.WholeFileReader对象配合来读取训练图像。
tf.train.string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
作用: 输出字符串到一个输入管道队列。
参数
输出:
一个字符串输出队列。 一个QueueRunner队列会被添加到当前的运算图的队列里。
范例:
__init__(name=None)
Class WholeFileReader
**对象介绍:**一个Reader对象,用于读取整个文件的内容作为value返回。
__init__(name=None)
作用: 创建一个WholeFileReader对象。
范例:
Reader = tf.WholeFileReader()
方法
tf.WholeFileReader.read(
queue,
name=None
)
作用: 返回读取的下一文件。
参数:
第一步,首先创建一个文件名的字符串列表。
filename = ['image1.png', 'image2.png','image3.png','image4.png','image5.png']
第二步,调用tf.train.string_input_producer字符串列表放入队列中,获取队列的Tensor实例。
filename_queue = tf.train.string_input_producer(filename,
shuffle=False,
num_epochs=1)
第三步,创建一个WholeFileReader对象,并且调用他的read方法,输入参数即之前获得的文件名队列,获得key和value两个Tensor实例。
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
第四步,启动一个tf.Session, 并调用*tf.local_variables_initializer().run()*初始化变量。
要启动队列,首先要调用 threads = tf.train.start_queue_runners(sess=sess) 来初始化队列。
然后,就可以使用sess.run()方法来获得队列里的数据额
with tf.Session() as sess:
tf.local_variables_initializer().run()
# 使用start_queue_runners之后,才会开始填充队列
threads = tf.train.start_queue_runners(sess=sess)
try:
while True:
imagefile, image_data = sess.run([key,value])
print(imagefile)
except:
print('End of files')
代码实例
filename = ['image1.png', 'image2.png','image3.png','image4.png','image5.png'] filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1) reader = tf.WholeFileReader() key, value = reader.read(filename_queue) with tf.Session() as sess: tf.local_variables_initializer().run() # 使用start_queue_runners之后,才会开始填充队列 threads = tf.train.start_queue_runners(sess=sess) try: i = 0 while True: image_data = sess.run(value) i += 1 (i) except: print('End of files')
这种数据提供法是Tensorflow 2.0 推荐的,更具灵活性。
此处, 我们有5个训练数据,5张图片和5个标签。
这5张图片存储在磁盘上,文件名则在imagelist里。标签则存在labels列表里。
第一步,我们要用到tf.data.Dataset对象,使用方法tf.data.Dataset.from_tensor_slices 来创建对象。此时,数据队列是一个图像文件名和一个标签。
labels = [0,1,2,3,4,5]
imagelist = ['0001.png','0204.png','0205.png','0206.png','0207.png', '0208.png']
dataset = tf.data.Dataset.from_tensor_slices((imagelist, labels))
第二步,我们此处真正用于训练的是图像数据,而不是文件名,因此此处定义一个方法,通过文件名获得图像数据, 并且和标签配对。该方法的输入参数是一个文件名,和对应的标签值。此处需要注意的是,读取文件,图像解码和处理的方法都必须是tensorflow的方法,此处仅仅是定义tensorflow运算/操作,需要启动tf.Session的run方法,才会真正执行定义的操作。
def get_one_data(filename, label):
image = tf.read_file(filename)
image = tf.image.decode_png(image)
image = tf.image.resize_images(image , [448, 448])
return image , label
第三步,我们使用Dataset对象的map方法映射之前定义的get_one_data方法。此处,get_one_data方法就像一个回调函数一样被调用。
dataset = dataset.map(get_one_data)
第四步,通过调用make_one_shot_iterator对象来获得一个iterator(迭代)对象,这个iterator对象每次返回一条记录。再调用iterator的*get_next()*来获得下一个元素。
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
第五步,开启一个tf.Session() as sess: 来启动tensorflow运算。sess.run的返回值是get_one_data的返回值,一个元组,第一个元素是图像数据,第二个元素是标签数据。
with tf.Session() as sess:
try:
while True:
data_value = sess.run(one_element)
print("Image Data")
print(data_value[0])
print("Label Data")
print(data_value[1])
except tf.errors.OutOfRangeError:
print("end!")
代码实例
def get_one_data(filename, label): image = tf.read_file(filename) image = tf.image.decode_png(image) print(image) image = tf.image.resize_images(image , [448, 448]) return image , label labels = [0,1,2,3,4,5] imagelist = ['0001.png','0204.png','0205.png','0206.png','0207.png', '0208.png'] dataset = tf.data.Dataset.from_tensor_slices((imagelist, labels)) dataset = dataset.map(get_one_data) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: i = 0 while True: data_value = sess.run(one_element) print(data_value[0]) print(data_value[1]) i += 1 except tf.errors.OutOfRangeError: print("end!")
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。