当前位置:   article > 正文

tensorflow2.x 数据集相关知识和操作_tensorflow_datasets.load('mnist')

tensorflow_datasets.load('mnist')

1 数据集载入

1.1 载入开箱即用的数据集

首先需要安装一个独立的Python包提供支持:

pip install tensorflow-datasets
  • 1

导入mnist数据集示例:

# 导入相关包
import tensorflow as tf
import tensorflow_datasets as tfds

# 最基础的方法tfds.load载入
dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

说明:
tfds.load返回的是一个tf.data.Dataset类型的对象,由一些列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数 的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。

  1. as_supervised参数:若为True,则根据数据集的特性,将数据集中的每行元素整理为有监督的二元组 (input, label) (即 “数据 + 标签”)形式,否则数据集中的每行元素为包含所有特征的字典。
  2. split:指定返回数据集的特定部分。若不指定,则返回整个数据集。一般有 tfds.Split.TRAIN (训练集)和 tfds.Split.TEST (测试集)选项。

1.2 数据量较少的数据集

tf.data.Dataset.from_tensor_slices()方法。

具体而言,如果我们的数据集中的所有元素通过张量的第 0 维,拼接成一个大的张量(例如,前节的 MNIST 数据集的训练集即为一个 [60000, 28, 28, 1] 的张量,表示了 60000 张 28*28 的单通道灰度图像),那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入,即可按张量的第 0 维展开来构建数据集,数据集的元素数量为张量第 0 维的大小。具体示例如下:

import tensorflow as tf
import numpy as np

X = tf.constant([2013, 2014, 2015, 2016, 2017])
Y = tf.constant([12000, 14000, 15000, 16500, 17500])

# 也可以使用NumPy数组,效果相同
# X = np.array([2013, 2014, 2015, 2016, 2017])
# Y = np.array([12000, 14000, 15000, 16500, 17500])

dataset = tf.data.Dataset.from_tensor_slices((X, Y))

for x, y in dataset:
    print(x.numpy(), y.numpy()) 

#-----------输出-------------
2013 12000
2014 14000
2015 15000
2016 16500
2017 17500
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

零维张量:指的是张量内只有一个元素,相当于标量,没有方向。

解释:可见创建了x,y两个零维张量,每个张量里边有5个元素,该方法会将这两个张量分片为5个元素并一一对应拼接构成5个二元组。因此,当多个张量作为输入时,张量的第零维大小必须相同,才能一一对应上构成元组。

1.3 数据量特别巨大而无法完整载入内存的数据集

可以先将数据集处理为 TFRecord 格式,然后使用 tf.data.TFRocordDataset() 进行载入。

TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。

暂时用不上,先不做考虑。

参考文档

2 数据集预处理

tf.data.Dataset 类为我们提供了多种数据集预处理方法。最常用的如:

  1. Dataset.map(f):对数据集中的每个元素应用函数 f ,得到一个新的数据集(这部分往往结合 tf.io 进行读写和解码文件, tf.image 进行图像处理)
  2. Dataset.shuffle(buffer_size) :将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前 buffer_size 个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换)
  3. Dataset.batch(batch_size) :将数据集分成批次,即对每 batch_size 个元素,使用 tf.stack() 在第 0 维合并,成为一个元素;
  4. Dataset.repeat():重复数据集的元素

2.1 Dataset.map(f)

示例代码:

def rot90(image, label):
    image = tf.image.rot90(image)
    return image, label

mnist_dataset = mnist_dataset.map(rot90)

for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :, 0])
    plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

解释:def rot90即为定义了一个将图片旋转90°的函数f,然后作为参数传入mnist_dataset.map(rot90)方法中,实现了对数据集中每个元素(图片)旋转90°。

2.2 Dataset.batch(batch_size)

以下示例代码将数据集按每个批次大小为4划分批次:

mnist_dataset = mnist_dataset.batch(4)

for images, labels in mnist_dataset:    # image: [4, 28, 28, 1], labels: [4]
    fig, axs = plt.subplots(1, 4)
    for i in range(4):
        axs[i].set_title(labels.numpy()[i])
        axs[i].imshow(images.numpy()[i, :, :, 0])
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

解释:源数据集的大小是[60000, 28, 28, 1],有60000张28*28的单通道灰度图片;分批次之后变成[4, 28, 28, 1],变成了只有4张图片。

2.3 Dataset.shuffle(buffer_size)

以下示例代码将数据集打乱之后再设置批次,缓存大小设置为10000(即缓存了10000张图片,采样时从缓冲区随机取样):

mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)

for images, labels in mnist_dataset:
    fig, axs = plt.subplots(1, 4)
    for i in range(4):
        axs[i].set_title(labels.numpy()[i])
        axs[i].imshow(images.numpy()[i, :, :, 0])
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

缓冲区具体设置细节:

  1. 设定一个固定大小为 buffer_size 的缓冲区(Buffer)
  2. 初始化时,取出数据集中的前 buffer_size 个元素放入缓冲区
  3. 每次需要从数据集中取元素时,即从缓冲区中随机采样一个元素并取出,然后从后续的元素中取出一个放回到之前被取出的位置,以维持缓冲区的大小。

缓冲区大小buffer_size的设置:

  1. 缓冲区的大小需要根据数据集的 特性数据排列顺序 特点来进行合理的设置。
  2. 当数据集的标签顺序分布极为不均匀(例如二元分类时数据集前 N 个的标签为 0,后 N 个的标签为 1)时,较小的缓冲区大小会使得训练时取出的 Batch 数据很可能全为同一标签,从而影响训练效果。一般而言,数据集的顺序分布若较为随机,则缓冲区的大小可较小,否则则需要设置较大的缓冲区。

2.4 Dataset.prefetch()

当训练模型时,我们希望充分利用计算资源,减少 CPU/GPU 的空载时间。然而有时,数据集的准备处理非常耗时,使得我们在每进行一次训练前都需要花费大量的时间准备待训练的数据,而此时 GPU 只能空载而等待数据,造成了计算资源的浪费,如下图所示:请添加图片描述
常规训练流程,在准备数据时,GPU 只能空载。

使用 Dataset.prefetch() 方法进行数据预加载后的训练流程,在 GPU 进行训练的同时 CPU 进行数据预加载,提高了训练效率。如下图所示:
请添加图片描述
代码示例:

mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  • 1

此处参数 buffer_size 既可手工设置,也可设置为 tf.data.experimental.AUTOTUNE 从而由 TensorFlow 自动选择合适的数值。

3 数据集元素的获取与使用

3.1 for循环迭代获取

tf.data.Dataset 是一个 Python 的可迭代对象,因此可以使用 For 循环迭代获取数据,即:

dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
for a, b, c, ... in dataset:
    # 对张量a, b, c等进行操作,例如送入模型进行训练
  • 1
  • 2
  • 3

3.2 创建Python迭代器

使用 iter() 显式创建一个 Python 迭代器并使用 next() 获取下一个元素,即:

dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
it = iter(dataset)
a_0, b_0, c_0, ... = next(it)
a_1, b_1, c_1, ... = next(it)
  • 1
  • 2
  • 3
  • 4

3.3 常规的keras传入数据集训练

常规keras训练需要将x和y都指定,分别为训练数据和训练标签:

model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)
  • 1

3.4 使用 tf.data.Dataset 直接作为输入

当调用 tf.keras.Modelfit()evaluate() 方法时,可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset ,并忽略掉参数中的标签数据 y
示例代码如下:

model.fit(mnist_dataset, epochs=num_epochs)
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/123858
推荐阅读
相关标签
  

闽ICP备14008679号