当前位置:   article > 正文

[Tensorflow] 第四课 训练数据读取的几种方法_dataread.build_dataset_firstread

dataread.build_dataset_firstread

使用Tensorflow搭建卷积网络用于各种训练时,需要处理训练的图像和标签, 批量的输送给训练的网络。 Tensorflow训练数据的读取方法按我的理解可以分两类。
第一类,使用queue队列。第二类,使用tf.data.Dataset 对象。 第一类方法是传统的数据读取方法,使用简单,只需要两三行代码就可以实现,但缺点是数据需要完整的载入队列,对内存的消耗较大。因此,在使用一些比较小的数据集时,比如CIFAR-10,可以采用这种方法。第二类方法则是tensorflow 2.0 官方推荐的方法,更具灵活性。

第一类,使用queue(队列)

通过将训练数据放入一个队列里,使用tensorflow的队列方法来为训练提供数据。
可以使用tf.train.slice_input_producertf.train.batch配合来获得训练数据。具体方法,参考我的第一篇 Tensorflow系列文章 [Tensorflow]第一课 创建一个数据队列

这里再载介绍一个更最简单的方法,使用tf.train的string_input_producer 方法和 tf.WholeFileReader对象配合来读取训练图像。

1. tf.train.string_input_producer方法介绍
 tf.train.string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

作用: 输出字符串到一个输入管道队列。
参数

  • string_tensor: 字符串列表
  • num_epochs:(可选)整数类型。读取字符串列表中每个字符串被放入队列的次数。如果不指定,则队列会循环使用字符串列表。
  • shuffle: (可选)布尔类型。True则乱顺序输出,False则按顺序输出。
  • seed:(可选)整数类型,如果shuffle 为True,则使用种子。
  • capacity:(可选)整数类型,设置队列容量。
  • shared_name:(可选)按照官方文档的说法,如果选择设置的话,这个队列可以被共享于多个session。在和这个queue同一设备上的session都可以通过这个共享名字来访问队列。官方文档中说,在分布式设置的程序里使用这个共享名字,意味着每个名字仅可以被一个session可见。
  • name:(可选)此操作的名称。
  • cancel_op:(可选)取消队列的操作。

输出:
一个字符串输出队列。 一个QueueRunner队列会被添加到当前的运算图的队列里。
范例:

__init__(name=None)
  • 1
2. tf.WholeFileReader对象介绍
Class WholeFileReader
  • 1

**对象介绍:**一个Reader对象,用于读取整个文件的内容作为value返回。

__init__(name=None)
  • 1

作用: 创建一个WholeFileReader对象。
范例:

Reader = tf.WholeFileReader()
  • 1

方法

tf.WholeFileReader.read(
	queue,
    name=None
)
  • 1
  • 2
  • 3
  • 4

作用: 返回读取的下一文件。
参数:

  • queue: 一个队列,或一个指向一个字符串队列的Tensor实例。
  • name:(可选)此操作的名称。
    返回值:
    一个元组包括2个Tensors对象 (key,value)
    key: Tensor对象,指向queue里的值。
    value: Tensor对象,指向读取文件内容
3. 使用方法

第一步,首先创建一个文件名的字符串列表。

filename = ['image1.png', 'image2.png','image3.png','image4.png','image5.png']
  • 1

第二步,调用tf.train.string_input_producer字符串列表放入队列中,获取队列的Tensor实例。

filename_queue = tf.train.string_input_producer(filename, 
				shuffle=False, 
				num_epochs=1)
  • 1
  • 2
  • 3

第三步,创建一个WholeFileReader对象,并且调用他的read方法,输入参数即之前获得的文件名队列,获得key和value两个Tensor实例。

reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
  • 1
  • 2

第四步,启动一个tf.Session, 并调用*tf.local_variables_initializer().run()*初始化变量。
要启动队列,首先要调用 threads = tf.train.start_queue_runners(sess=sess) 来初始化队列。
然后,就可以使用sess.run()方法来获得队列里的数据额

with tf.Session() as sess:
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess)
    try: 
        while True:
            imagefile, image_data = sess.run([key,value])
            print(imagefile)
    except: 
        print('End of files')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

代码实例

filename = ['image1.png', 'image2.png','image3.png','image4.png','image5.png']
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)

with tf.Session() as sess:
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    threads = tf.train.start_queue_runners(sess=sess)
    try: 
        i = 0
        while True:
            image_data = sess.run(value)
            i += 1
            (i)
    except: 
        print('End of files')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

第二类,使用tf.data.Dataset对象

这种数据提供法是Tensorflow 2.0 推荐的,更具灵活性。
此处, 我们有5个训练数据,5张图片和5个标签。
这5张图片存储在磁盘上,文件名则在imagelist里。标签则存在labels列表里。
第一步,我们要用到tf.data.Dataset对象,使用方法tf.data.Dataset.from_tensor_slices 来创建对象。此时,数据队列是一个图像文件名和一个标签。

labels = [0,1,2,3,4,5]
imagelist = ['0001.png','0204.png','0205.png','0206.png','0207.png', '0208.png']
dataset = tf.data.Dataset.from_tensor_slices((imagelist, labels))
  • 1
  • 2
  • 3

第二步,我们此处真正用于训练的是图像数据,而不是文件名,因此此处定义一个方法,通过文件名获得图像数据, 并且和标签配对。该方法的输入参数是一个文件名,和对应的标签值。此处需要注意的是,读取文件,图像解码和处理的方法都必须是tensorflow的方法,此处仅仅是定义tensorflow运算/操作,需要启动tf.Sessionrun方法,才会真正执行定义的操作。

def get_one_data(filename, label):
    image = tf.read_file(filename)
    image  = tf.image.decode_png(image)
    image  = tf.image.resize_images(image , [448, 448])    
    return image , label
  • 1
  • 2
  • 3
  • 4
  • 5

第三步,我们使用Dataset对象的map方法映射之前定义的get_one_data方法。此处,get_one_data方法就像一个回调函数一样被调用。

dataset = dataset.map(get_one_data)
  • 1

第四步,通过调用make_one_shot_iterator对象来获得一个iterator(迭代)对象,这个iterator对象每次返回一条记录。再调用iterator的*get_next()*来获得下一个元素。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
  • 1
  • 2

第五步,开启一个tf.Session() as sess: 来启动tensorflow运算。sess.run的返回值是get_one_data的返回值,一个元组,第一个元素是图像数据,第二个元素是标签数据。

with tf.Session() as sess:
    try:
        while True:
            data_value = sess.run(one_element)
            print("Image Data")
            print(data_value[0])
            print("Label Data")
            print(data_value[1])
    except tf.errors.OutOfRangeError:
        print("end!")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

代码实例

def get_one_data(filename, label):
    image = tf.read_file(filename)
    image  = tf.image.decode_png(image)
    print(image)
    image  = tf.image.resize_images(image , [448, 448])    
    return image , label
    
labels = [0,1,2,3,4,5]
imagelist = ['0001.png','0204.png','0205.png','0206.png','0207.png', '0208.png']
dataset = tf.data.Dataset.from_tensor_slices((imagelist, labels))
dataset = dataset.map(get_one_data)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    try:
        i = 0
        while True:
            data_value = sess.run(one_element)
            print(data_value[0])
            print(data_value[1])
            i += 1
    except tf.errors.OutOfRangeError:
        print("end!")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/123862
推荐阅读
相关标签
  

闽ICP备14008679号