赞
踩
目录
Pytorch深度学习库以一种可读性强、模块化程度高的方式来构建深度学习网络。在构建深度学习网络时,数据的加载和预处理是一项重要而繁琐的工作。如果在构建网络中, 我们需要为加载样本数据、样本数据预处理编写大量的处理代码,会导致代码变得混乱、网络构建过程不清晰,最终难以维护。
基于以上考虑,Pytorch将数据集和数据集的加载定义为两个单独对象,使数据集代码和模型训练代码相分离,以获得更好的可读性和模块化。
Pytorch提供了两个DataSet和DataLoader两个类。
DataSet是数据集对象类, Pytorch提供了大量的默认数据集, 包括Fashion-MINST、CIFAR-10、CIFAR-100、CelebA等数据集。如果用户想要加载自定义的数据只需要继承DataSet类。
Pytorch支持两种类型的DataSet:
Map类型DataSet实现__getitem__()
和 __len__()
,表示从索引/键到数据样本的映射。数据集在使用 访问时,可以通过索引直接获取相关样本数据。例如,dataset[idx]表示使用
idx
从磁盘上的文件夹中读取第i个图像及其相应的标签。
IterableDataset
实现了__iter__()函数
,可对数据样本进行迭代访问。这种类型的数据集特别适用于随机读取代价高昂以及批量大小取决于获取的数据等场景。
例如,在从数据库、远程服务器甚至实时生成的日志中读取的数据流场景中,可以使用iter(dataset)
来访问数据
- import torch
- from torchvision import datasets
-
- training_data = datasets.FashionMNIST(
- root="data",
- train=True,
- download=True,
- transform=ToTensor()
- )
-
- test_data = datasets.FashionMNIST(
- root="data",
- train=False,
- download=True,
- transform=ToTensor()
- )
上面示例中,会下载Pytorch提供的数据集FashionMNIST。参数说明:
Dataloader 是一个迭代器,最基本的使用就是传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据.
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
常用参数说明:
示例代码
- from torch.utils.data import DataLoader
-
- train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
- test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
-
- train_features, train_labels = next(iter(train_dataloader))
上面代码中,DataLoader封装training_data、test_data为迭代器。 通过train_dataloader迭代获取样本数据,每个迭代获取64个(batch_size)个样本。
下例为Pytorch关于DataSet和DataLoader的官方示例。
该示例是从 TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST 是 Zalando 的文章图像数据集,由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和来自 10 个类别之一的相关标签。
我们使用以下参数加载FashionMNIST 数据集:
root
:存储训练/测试数据的路径,train
:指定训练或测试数据集,download=True:
如果数据不可用,则下载数据root
。transform&
target_transform:
指定特征和标签转换- import torch
- from torch.utils.data import Dataset
- from torchvision import datasets
- from torch.utils.data import DataLoader
- from torchvision.transforms import ToTensor
- import matplotlib.pyplot as plt
-
-
- training_data = datasets.FashionMNIST(
- root="data",
- train=True,
- download=True,
- transform=ToTensor()
- )
-
- test_data = datasets.FashionMNIST(
- root="data",
- train=False,
- download=True,
- transform=ToTensor()
- )
-
- labels_map = {
- 0: "T-Shirt",
- 1: "Trouser",
- 2: "Pullover",
- 3: "Dress",
- 4: "Coat",
- 5: "Sandal",
- 6: "Shirt",
- 7: "Sneaker",
- 8: "Bag",
- 9: "Ankle Boot",
- }
- figure = plt.figure(figsize=(8, 8))
- cols, rows = 3, 3
- for i in range(1, cols * rows + 1):
- sample_idx = torch.randint(len(training_data), size=(1,)).item()
- img, label = training_data[sample_idx]
- figure.add_subplot(rows, cols, i)
- plt.title(labels_map[label])
- plt.axis("off")
- plt.imshow(img.squeeze(), cmap="gray")
- plt.show()
-
-
- train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
- test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
-
- train_features, train_labels = next(iter(train_dataloader))
- print(f"Feature batch shape: {train_features.size()}")
- print(f"Labels batch shape: {train_labels.size()}")
- img = train_features[0].squeeze()
- label = train_labels[0]
- plt.imshow(img, cmap="gray")
- plt.show()
- print(f"Label: {label}")
将该数据集加载到 中,DataLoader
并且可以根据需要遍历数据集。每次迭代都会返回一批(64个样本)train_features
和train_labels
(batch_size=64
分别包含特征和标签)。因为指定了shuffle=True
,在遍历所有批次后,数据会被打乱。
Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64]) Label: 0
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。