当前位置:   article > 正文

Pytorh学习——DataSet和DataLoader

dataset和dataloader

目录

Pytorch的数据集

DataSet

DataLoader

创建自定义数据集

参考文档


Pytorch的数据集

Pytorch深度学习库以一种可读性强、模块化程度高的方式来构建深度学习网络。在构建深度学习网络时,数据的加载和预处理是一项重要而繁琐的工作。如果在构建网络中, 我们需要为加载样本数据、样本数据预处理编写大量的处理代码,会导致代码变得混乱、网络构建过程不清晰,最终难以维护。

基于以上考虑,Pytorch将数据集和数据集的加载定义为两个单独对象,使数据集代码和模型训练代码相分离,以获得更好的可读性和模块化。

Pytorch提供了两个DataSet和DataLoader两个类。

DataSet

DataSet是数据集对象类, Pytorch提供了大量的默认数据集, 包括Fashion-MINST、CIFAR-10、CIFAR-100、CelebA等数据集。如果用户想要加载自定义的数据只需要继承DataSet类。

Pytorch支持两种类型的DataSet:

  • Map类型DataSet
  • Iterable类型DataSet

Map类型DataSet

Map类型DataSet实现__getitem__()和 __len__(),表示从索引/键到数据样本的映射。数据集在使用 访问时,可以通过索引直接获取相关样本数据。例如,dataset[idx]表示使用idx从磁盘上的文件夹中读取第i个图像及其相应的标签。

Iterable类型DataSet

IterableDataset 实现了__iter__()函数,可对数据样本进行迭代访问。这种类型的数据集特别适用于随机读取代价高昂以及批量大小取决于获取的数据等场景。

例如,在从数据库、远程服务器甚至实时生成的日志中读取的数据流场景中,可以使用iter(dataset)来访问数据

代码示例

  1. import torch
  2. from torchvision import datasets
  3. training_data = datasets.FashionMNIST(
  4. root="data",
  5. train=True,
  6. download=True,
  7. transform=ToTensor()
  8. )
  9. test_data = datasets.FashionMNIST(
  10. root="data",
  11. train=False,
  12. download=True,
  13. transform=ToTensor()
  14. )

上面示例中,会下载Pytorch提供的数据集FashionMNIST。参数说明:

  • root:  数据集文件路径,DataSet会到该目录下查找相关数据集文件。
  • train: 是否是训练数据。 True:表示训练数据; False: 测试数据。如果为True,则下载训练数据集;如果为False,则下载测试数据集。
  • download: 当本地数据集不存在时, 是否远程下载到本地。
  • transform: 数据转换操作, ToTensor()将数据转换为张量。

DataLoader

Dataloader 是一个迭代器,最基本的使用就是传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据.

DataLoader函数

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

常用参数说明:

  • dataset: 需要加载的数据集
  • batch_size: 每个迭代返回的样本数
  • shuffle: 如果为True,则每次epoch时对数据进行shuffle操作。

示例代码

  1. from torch.utils.data import DataLoader
  2. train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
  3. test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  4. train_features, train_labels = next(iter(train_dataloader))

上面代码中,DataLoader封装training_data、test_data为迭代器。 通过train_dataloader迭代获取样本数据,每个迭代获取64个(batch_size)个样本。

加载Fashion-MNIST数据集

下例为Pytorch关于DataSet和DataLoader的官方示例

该示例是从 TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 的文章图像数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和来自 10 个类别之一的相关标签。

我们使用以下参数加载FashionMNIST 数据集

  • root :存储训练/测试数据的路径,
  • train :指定训练或测试数据集,
  • download=True:如果数据不可用,则下载数据root
  • transform&target_transform:指定特征和标签转换

示例代码

  1. import torch
  2. from torch.utils.data import Dataset
  3. from torchvision import datasets
  4. from torch.utils.data import DataLoader
  5. from torchvision.transforms import ToTensor
  6. import matplotlib.pyplot as plt
  7. training_data = datasets.FashionMNIST(
  8. root="data",
  9. train=True,
  10. download=True,
  11. transform=ToTensor()
  12. )
  13. test_data = datasets.FashionMNIST(
  14. root="data",
  15. train=False,
  16. download=True,
  17. transform=ToTensor()
  18. )
  19. labels_map = {
  20. 0: "T-Shirt",
  21. 1: "Trouser",
  22. 2: "Pullover",
  23. 3: "Dress",
  24. 4: "Coat",
  25. 5: "Sandal",
  26. 6: "Shirt",
  27. 7: "Sneaker",
  28. 8: "Bag",
  29. 9: "Ankle Boot",
  30. }
  31. figure = plt.figure(figsize=(8, 8))
  32. cols, rows = 3, 3
  33. for i in range(1, cols * rows + 1):
  34. sample_idx = torch.randint(len(training_data), size=(1,)).item()
  35. img, label = training_data[sample_idx]
  36. figure.add_subplot(rows, cols, i)
  37. plt.title(labels_map[label])
  38. plt.axis("off")
  39. plt.imshow(img.squeeze(), cmap="gray")
  40. plt.show()
  41. train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
  42. test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
  43. train_features, train_labels = next(iter(train_dataloader))
  44. print(f"Feature batch shape: {train_features.size()}")
  45. print(f"Labels batch shape: {train_labels.size()}")
  46. img = train_features[0].squeeze()
  47. label = train_labels[0]
  48. plt.imshow(img, cmap="gray")
  49. plt.show()
  50. print(f"Label: {label}")

将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。每次迭代都会返回一批(64个样本)train_featurestrain_labelsbatch_size=64分别包含特征和标签)。因为指定了shuffle=True,在遍历所有批次后,数据会被打乱。

运行结果

../../_images/sphx_glr_data_tutorial_001.png

  1. Feature batch shape: torch.Size([64, 1, 28, 28])
  2. Labels batch shape: torch.Size([64])
  3. Label: 0

参考文档

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/318272
推荐阅读
相关标签
  

闽ICP备14008679号