赞
踩
参考 基于tensorflow的图像处理(四) 数据集处理 - 云+社区 - 腾讯云
除队列以外,tensorflow还提供了一套更高的数据处理框架。在新的框架中,每一个数据来源被抽象成一个“数据集”,开发者可以以数据集为基本对象,方便地进行batching、随机打乱(shuffle)等操作。
在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord文件,一个文本文件,或者经过sharding的一系列文件,等等。由于训练数据集通常无法全部写入内存中,从数据中读取数据时需要使用一个迭代器(iterator)按顺序进行读取,这点与队列的dequeue()操作和Reader的read()操作相似。与队列相似,数据集也是计算图上的一个点。
下面先看一个简单的例子,这个例子从一个张量创建一个数据集,遍历这个数据集,并对每个输入输出y=x^2的值。
- import tensorflow as tf
-
- # 从一个数组创建数据集。
- input_data = [1, 2, 3, 4, 5, 6]
- dataset = tf.data.Dataset.from_tensor_slices(input_data)
-
-
- # 定义一个迭代器用于遍历数据集。因为上面定义的数据集没有用placeholder
- # 作为输入参数,所以这里可以使用最简单的one_shot_iterator。
- iterator = dataset.make_one_shot_iterator()
-
- # get_next() 返回代表一个输入数据的张量,类似于队列中dequeue()。
-
- x = iterator.get_next()
- y = x * x
-
-
- with tf.Session() as sess:
- for i in range(len(input_data)):
- print(sess.run(y))
-
-
- 输出:
- ---
- 1
- 4
- 9
- 16
- 25
- 36
- ---
从以上例子可以看到,利用数据集读取数据有三个基本步骤。
1.定义数据集的构造方法
这个例子使用了tf.data.Dataset.from_tensor_slice(),表明数据集是从一个张量中构建的。如果数据集是从文件中构建的,则需要相应调用不同的构造方法。
2.定义遍历器
这个例子使用了最简单的one_shot_iterator来遍历数据集。
3.使用get_next()方法从遍历器中读取数据张量,作为计算图其他部分的输入
在图像相关任务中,输入数据通常以TFRecord形式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同, 每一个TFRecord都有自己不同的feature格式,因此在读取TFRecord时,需要提供一个parser函数来解析所读取的TFRecord的数据格式。
- import tensorflow as tf
-
- # 解析一个TFRecord的方法。record是从文件中读取的一个样例。
- def parser(record):
- # 解析读入的一个样例
- features = tf.parse_single_example(
- record,
- features={
- 'feat1': tf.FixedLenFeature([], tf.int64),
- 'feat2': tf.FixedLenFeature([], tf.int64),
- })
- return features['feat1'], features['feat2']
-
-
- # 从TFRecord文件创建数据集
- input_files = ["/path/to/input_file1", "/path/to/input_fi
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。