赞
踩
#! https://zhuanlan.zhihu.com/p/543481537
本次笔记记录如何构建数据集,如何搭建minibatch,以及DataLoader源码剖析.
#使用Torchvision导入内置数据集 import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() ) test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() )
继承抽象类Dataset,并自己实现以下三个函数:
__init__
__len__
__getitem__
import os import pandas as pd from torchvision.io import read_image class CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])#每张图片路径 image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image)#预处理 if self.target_transform: label = self.target_transform(label)#预处理label return image, label
这种dataset叫做map-style datasets
读流式数据可以用iterable-style dataset
使用Minibatch的形式训练数据,在每个epoch都reshuffle数据从而降低模型过拟合的可能性,还通过python的multiprocessing多进程加载数据,使得读数据的过程不影响gpu训练,实现低延迟。(有的数据集读取的过程贼慢,比方说TextVQA这种,很影响GPU训练的效率)
DataLoader返回的是一个列表。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#test不需要梯度下降,只要前向传递,不需要shuffle
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
展示一张图片与标签,
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()#灰度图片要变成三通道
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
官网的教程过于简单,下面深入剖析DataLoader,
包括sampler,collate_fn处理等,还有一些排序方法(比如让一个minibatch中样本长度差不多,这样在进行padding等操作的时候就会简单很多)。
主要解析三部分代码:
首先看dataloader
''' dataset: Dataset[T_co]:传入一个实例化的dataset对象 shuffle:每个epoch后打乱数据,一般在训练集上用 sampler/batch_sampler:自定义采样方式,显然与shuffle冲突 num_workers:进程数量,多进程读数据,0代表只用一个主进程 pin_memory:把tensor保存在gpu上,不用重复保存,对于效率影响有待考究 drop_last:样本数量不是batch_size整数倍时,把最后一部分数据丢掉 collate_fn:自定义聚集函数,对小批次数据进行再次处理,输入一个Batch,输出一个batch timeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 ''' def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: bool = False, sampler: Optional[Sampler[int]] = None, batch_sampler: Optional[Sampler[Sequence[int]]] = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: int = 2, persistent_workers: bool = False): torch._C._log_api_usage_once("python.data_loader") if num_workers < 0: raise ValueError('num_workers option should be non-negative; ' 'use num_workers=0 to disable multiprocessing.') if timeout < 0: raise ValueError('timeout option should be non-negative') if num_workers == 0 and prefetch_factor != 2: raise ValueError('prefetch_factor option could only be specified in multiprocessing.' 'let num_workers > 0 to enable multiprocessing.') assert prefetch_factor > 0 if persistent_workers and num_workers == 0: raise ValueError('persistent_workers option needs num_workers > 0') #设置成员变量 self.dataset = dataset self.num_workers = num_workers self.prefetch_factor = prefetch_factor self.pin_memory = pin_memory self.timeout = timeout self.worker_init_fn = worker_init_fn self.multiprocessing_context = multiprocessing_context #后续代码太长,这里不一一展示,后面不算很难
后续代码太长,这里不一一展示,init函数主要做三件事:
这里说几点重要的,方便放矢地看源码:
1.构建sampler
1.1.shuffle内部通过RandomSampler的具体实现,重点看__iter__.
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=generator).tolist()
1.2.shuffle = None通过SequentialSampler实现.
class SequentialSampler(Sampler[int]): r"""Samples elements sequentially, always in the same order. Args: data_source (Dataset): dataset to sample from """ data_source: Sized def __init__(self, data_source: Sized) -> None: self.data_source = data_source def __iter__(self) -> Iterator[int]: return iter(range(len(self.data_source))) def __len__(self) -> int: return len(self.data_source)
2.我们不使用batchsampler而仅仅是设定batch_size,内部其实还是通过BatchSampler实现的.
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
这里插入说一下python的迭代器和生成器,方便下面的理解
''' 迭代是Python最强大的功能之一,是访问集合元素的一种方式。 迭代器是一个可以记住遍历的位置的对象。 迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。 迭代器有两个基本的方法:iter() 和 next()。 字符串,列表或元组对象都可用于创建迭代器: ''' >>> list=[1,2,3,4] >>> it = iter(list) # 创建迭代器对象 >>> print (next(it)) # 输出迭代器的下一个元素 1 >>> print (next(it)) 2 #迭代器对象可以使用常规for语句进行遍历: list=[1,2,3,4] it = iter(list) # 创建迭代器对象 for x in it: print (x, end=" ") ''' 在 Python 中,使用了 yield 的函数被称为生成器(generator)。 跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。 在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。 调用一个生成器函数,返回的是一个迭代器对象。 以下实例使用 yield 实现斐波那契数列: ''' import sys def fibonacci(n): # 生成器函数 - 斐波那契 a, b, counter = 0, 1, 0 while True: if (counter > n): return yield a a, b = b, a + b counter += 1 f = fibonacci(10) # f 是一个迭代器,由生成器返回生成 while True: try: print (next(f), end=" ") except StopIteration: sys.exit()
回归BatchSampler实现:
def __iter__(self) -> Iterator[List[int]]:
batch = []
for idx in self.sampler:
batch.append(idx)
f len(batch) == self.batch_size:
yield batch #生成器,每次next返回一个batch
batch = []
if len(batch) > 0 and not self.drop_last:#drop_last的原理
yield batch
3.collate_fn自定义聚集函数
if collate_fn is None:
if self._auto_collation: #batch_sample is not None ->true
collate_fn = _utils.collate.default_collate #输入输出都是batch,做了一些数据类型的转换
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
4.迭代器iter()可以把dataloader变成一个迭代器
iter
-> _get_iterator()
_get_iterator()
-> _SingleProcessDataLoaderIter()
or _MultiProcessingDataLoaderIter()
_SingleProcessDataLoaderIter()
继承自BaseDataLoaderIter
.
我们发现_SingleProcessDataLoaderIter()
的_next_data
函数没有在子类中被调用,猜测可以在父类BaseDataLoaderIter
找到调用.
经查看,父类__next__
中确实调用了_next_data
,这就可以解释了代码的内部原理。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。