赞
踩
首先需要安装一个独立的Python包提供支持:
pip install tensorflow-datasets
导入mnist数据集示例:
# 导入相关包
import tensorflow as tf
import tensorflow_datasets as tfds
# 最基础的方法tfds.load载入
dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
说明:
tfds.load
返回的是一个tf.data.Dataset
类型的对象,由一些列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数
的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。
as_supervised
参数:若为True,则根据数据集的特性,将数据集中的每行元素整理为有监督的二元组 (input, label)
(即 “数据 + 标签”)形式,否则数据集中的每行元素为包含所有特征的字典。split
:指定返回数据集的特定部分。若不指定,则返回整个数据集。一般有 tfds.Split.TRAIN
(训练集)和 tfds.Split.TEST
(测试集)选项。tf.data.Dataset.from_tensor_slices()
方法。
具体而言,如果我们的数据集中的所有元素通过张量的第 0 维,拼接成一个大的张量(例如,前节的 MNIST 数据集的训练集即为一个 [60000, 28, 28, 1]
的张量,表示了 60000 张 28*28 的单通道灰度图像),那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入,即可按张量的第 0 维展开来构建数据集,数据集的元素数量为张量第 0 维的大小。具体示例如下:
import tensorflow as tf import numpy as np X = tf.constant([2013, 2014, 2015, 2016, 2017]) Y = tf.constant([12000, 14000, 15000, 16500, 17500]) # 也可以使用NumPy数组,效果相同 # X = np.array([2013, 2014, 2015, 2016, 2017]) # Y = np.array([12000, 14000, 15000, 16500, 17500]) dataset = tf.data.Dataset.from_tensor_slices((X, Y)) for x, y in dataset: print(x.numpy(), y.numpy()) #-----------输出------------- 2013 12000 2014 14000 2015 15000 2016 16500 2017 17500
零维张量:指的是张量内只有一个元素,相当于标量,没有方向。
解释:可见创建了x,y两个零维张量,每个张量里边有5个元素,该方法会将这两个张量分片为5个元素并一一对应拼接构成5个二元组。因此,当多个张量作为输入时,张量的第零维大小必须相同,才能一一对应上构成元组。
可以先将数据集处理为 TFRecord
格式,然后使用 tf.data.TFRocordDataset()
进行载入。
TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。
暂时用不上,先不做考虑。
tf.data.Dataset
类为我们提供了多种数据集预处理方法。最常用的如:
Dataset.map(f)
:对数据集中的每个元素应用函数 f
,得到一个新的数据集(这部分往往结合 tf.io
进行读写和解码文件, tf.image
进行图像处理)Dataset.shuffle(buffer_size)
:将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前 buffer_size
个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换)Dataset.batch(batch_size)
:将数据集分成批次,即对每 batch_size
个元素,使用 tf.stack()
在第 0 维合并,成为一个元素;Dataset.repeat()
:重复数据集的元素Dataset.map(f)
示例代码:
def rot90(image, label):
image = tf.image.rot90(image)
return image, label
mnist_dataset = mnist_dataset.map(rot90)
for image, label in mnist_dataset:
plt.title(label.numpy())
plt.imshow(image.numpy()[:, :, 0])
plt.show()
解释:def rot90
即为定义了一个将图片旋转90°的函数f
,然后作为参数传入mnist_dataset.map(rot90)
方法中,实现了对数据集中每个元素(图片)旋转90°。
Dataset.batch(batch_size)
以下示例代码将数据集按每个批次大小为4划分批次:
mnist_dataset = mnist_dataset.batch(4)
for images, labels in mnist_dataset: # image: [4, 28, 28, 1], labels: [4]
fig, axs = plt.subplots(1, 4)
for i in range(4):
axs[i].set_title(labels.numpy()[i])
axs[i].imshow(images.numpy()[i, :, :, 0])
plt.show()
解释:源数据集的大小是[60000, 28, 28, 1]
,有60000张28*28的单通道灰度图片;分批次之后变成[4, 28, 28, 1]
,变成了只有4张图片。
Dataset.shuffle(buffer_size)
以下示例代码将数据集打乱之后再设置批次,缓存大小设置为10000(即缓存了10000张图片,采样时从缓冲区随机取样):
mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)
for images, labels in mnist_dataset:
fig, axs = plt.subplots(1, 4)
for i in range(4):
axs[i].set_title(labels.numpy()[i])
axs[i].imshow(images.numpy()[i, :, :, 0])
plt.show()
缓冲区具体设置细节:
buffer_size
个元素放入缓冲区缓冲区大小buffer_size
的设置:
Dataset.prefetch()
当训练模型时,我们希望充分利用计算资源,减少 CPU/GPU 的空载时间。然而有时,数据集的准备处理非常耗时,使得我们在每进行一次训练前都需要花费大量的时间准备待训练的数据,而此时 GPU 只能空载而等待数据,造成了计算资源的浪费,如下图所示:
常规训练流程,在准备数据时,GPU 只能空载。
使用 Dataset.prefetch() 方法进行数据预加载后的训练流程,在 GPU 进行训练的同时 CPU 进行数据预加载,提高了训练效率。如下图所示:
代码示例:
mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
此处参数 buffer_size
既可手工设置,也可设置为 tf.data.experimental.AUTOTUNE
从而由 TensorFlow 自动选择合适的数值。
tf.data.Dataset
是一个 Python 的可迭代对象,因此可以使用 For 循环迭代获取数据,即:
dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
for a, b, c, ... in dataset:
# 对张量a, b, c等进行操作,例如送入模型进行训练
使用 iter()
显式创建一个 Python 迭代器并使用 next()
获取下一个元素,即:
dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
it = iter(dataset)
a_0, b_0, c_0, ... = next(it)
a_1, b_1, c_1, ... = next(it)
常规keras训练需要将x和y都指定,分别为训练数据和训练标签:
model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)
tf.data.Dataset
直接作为输入当调用 tf.keras.Model
的 fit()
和 evaluate()
方法时,可以将参数中的输入数据 x
指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset
,并忽略掉参数中的标签数据 y
。
示例代码如下:
model.fit(mnist_dataset, epochs=num_epochs)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。