赞
踩
PyTorch提供的两个常用数据API:
官方案例: Fashion-MNIST数据集
torchvision:torch的一个视觉库,将torchvision中的datasets导入进来,就能获得其中的各种数据集
FashionMNIST图像存储在目录img_dir中,标签存储在CSV文件annotations_file中
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() ) test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() )
对上述数据集进行可视化:
labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } figure = plt.figure(figsize=(8, 8)) cols, rows = 3, 3 for i in range(1, cols * rows + 1): sample_idx = torch.randint(len(training_data), size=(1,)).item() img, label = training_data[sample_idx] figure.add_subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") plt.show()
pytorch中的dataset类是在pytorch的torch下的utils之下的data文件夹里有一个dataset.py
包含图像、注释文件和两个转换:
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file) #标签存储在CSV文件annotations_file中
self.img_dir = img_dir #FashionMNIST图像存储在目录img_dir中
self.transform = transform #图像转换
self.target_transform = target_transform
返回数据集的样本数(就是img_labels的长度)
def __len__(self):
return len(self.img_labels)
输入索引index,getitem函数从数据集中加载并返回对应index的一个样本:
def __getitem__(self, idx):
#img_labels的第index行第0列标注了对应的照片文件名称
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path) #使用read_image将图像转换为张量
label = self.img_labels.iloc[idx, 1] #从self中的csv数据中检索相应的标签
#调用转换函数
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label #返回张量图像和相应的标签
import os import pandas as pd from torchvision.io import read_image class CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image, label
DataLoader通常是在torch.utils.data下
常用的参数有:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
主要做了三件事:构建sampler、构建batch_sampler、构建collate_fn
定义属性:
如果设置了自定义的sampler然后又设置了shuffle=true,这种情况是没有意义的:
(shuffle是官方自定义的一个随机sampler)
设置了batch_sampler的情况下,就不需要设置batch_size、shuffle、sampler和drop_last了:
如果没有设置sampler,则先判断数据集类型,如果使用的是map-style(else逻辑),就根据是否设置shuffle来选择pytorch内置的sampler:
设置了batch_size但是没有设置batch_sampler时,会使用内置的BatchSampler:
如果没有设置collate_fn,就判断auto_collation是否设置(auto_collation是根据batch_sampler是否是None来设置的,如果batch_sampler不是none,auto_collation就是true),default_collate是将batch作为输入,batch输出,并没有对数据做额外处理:
iter函数返回的是get_iterator的值:
get_iterator根据num_workers的设置选择对应的内置DataLoaderIter:
所以可知,iter函数最终返回的是一个dataloaderiter对象,以SingleProcessDataLoaderIter为例,类里有next_data函数:
SingleProcessDataLoaderIter类是继承了BaseDataLoaderIter类,BaseDataLoaderIter类中的next函数就是使用了子类中的next_data:
根据上述源码分析,就可以对dataloader去迭代iter之后调用next函数来获得每一批次的数据:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
由于batch_size=64,因此最终返回的Feature batch shape以及Labels batch shape均为64。
参考:
PyTorch官方文档:Datasets & DataLoaders
5、深入剖析PyTorch DataLoader源码
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。