Course link
While running experiments recently I realized that my grasp of the basic framework is still weak, so I started relearning PyTorch from the ground up. These are my course notes; the lecture itself is very detailed, and the notes are fairly rough.
The DataLoader constructor takes the parameters shown in the signature below:
Note: In deep learning and natural language processing (NLP), padding is a common preprocessing step, especially when working with variable-length sequences (text, time series, etc.). When a DataLoader pulls items from the dataset in batches, and the items (e.g., sentences or time series) have different lengths, they usually need to be padded (or truncated) to the same length so the batch can be processed efficiently as a single tensor (e.g., via matrix operations).
This is where the collate_fn parameter comes in. By default, DataLoader uses a built-in collate_fn to assemble a batch of items into tensors, but that default does no padding. To pad, you supply a custom collate_fn, as in the sketch right after this note.
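For illustration, here is a minimal padding collate_fn sketch. The ToySequences dataset and the pad_collate function are made-up names for this example; DataLoader and torch.nn.utils.rnn.pad_sequence are the real PyTorch APIs being used.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class ToySequences(Dataset):
    """A toy dataset of variable-length 1-D integer sequences (hypothetical data)."""
    def __init__(self):
        self.data = [torch.tensor([1, 2, 3]),
                     torch.tensor([4, 5]),
                     torch.tensor([6, 7, 8, 9])]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def pad_collate(batch):
    # batch is the list of samples returned by __getitem__ for one mini-batch;
    # pad them to the length of the longest sequence in this batch.
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

loader = DataLoader(ToySequences(), batch_size=2, collate_fn=pad_collate)
for padded, lengths in loader:
    print(padded.shape, lengths)   # e.g. torch.Size([2, 3]) tensor([3, 2])

The key point is that a custom collate_fn always receives the whole batch as a list and returns whatever batched structure you want the training loop to see.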
The constructor's code, with comments, is as follows:
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
             batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: int = 2,
             persistent_workers: bool = False):
    torch._C._log_api_usage_once("python.data_loader")

    if num_workers < 0:
        raise ValueError('num_workers option should be non-negative; '
                         'use num_workers=0 to disable multiprocessing.')

    if timeout < 0:
        raise ValueError('timeout option should be non-negative')

    if num_workers == 0 and prefetch_factor != 2:
        raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                         'let num_workers > 0 to enable multiprocessing.')
    assert prefetch_factor > 0

    if persistent_workers and num_workers == 0:
        raise ValueError('persistent_workers option needs num_workers > 0')

    # Store the arguments as member attributes
    self.dataset = dataset
    self.num_workers = num_workers
    self.prefetch_factor = prefetch_factor
    self.pin_memory = pin_memory
    self.timeout = timeout
    self.worker_init_fn = worker_init_fn
    self.multiprocessing_context = multiprocessing_context

    # This branch can usually be skipped: we normally use a map-style Dataset,
    # not an IterableDataset, so jump straight to the else branch below.
    if isinstance(dataset, IterableDataset):
        self._dataset_kind = _DatasetKind.Iterable
        if isinstance(dataset, IterDataPipe):
            torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
        elif shuffle is not False:
            raise ValueError(
                "DataLoader with IterableDataset: expected unspecified "
                "shuffle option, but got shuffle={}".format(shuffle))

        if sampler is not None:
            # See NOTE [ Custom Samplers and IterableDataset ]
            raise ValueError(
                "DataLoader with IterableDataset: expected unspecified "
                "sampler option, but got sampler={}".format(sampler))
        elif batch_sampler is not None:
            # See NOTE [ Custom Samplers and IterableDataset ]
            raise ValueError(
                "DataLoader with IterableDataset: expected unspecified "
                "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
    # The branch we actually care about
    else:
        # The dataset kind is set to _DatasetKind.Map
        self._dataset_kind = _DatasetKind.Map

    # If you pass a custom sampler (default None) and also set shuffle=True, that makes no sense:
    # shuffle merely selects the built-in random sampler, and a custom sampler already defines the
    # sampling order. So shuffle and sampler are mutually exclusive and cannot both be set.
    if sampler is not None and shuffle:
        raise ValueError('sampler option is mutually exclusive with '
                         'shuffle')

    # batch_sampler samples at the batch level, sampler samples at the individual-sample level.
    if batch_sampler is not None:
        # If batch_size is not 1, or shuffle, sampler, or drop_last is set, they all conflict with
        # batch_sampler. In short: once you provide a batch_sampler you no longer set batch_size,
        # because the batch_sampler already tells PyTorch the batch size and how mini-batches are formed.
        if batch_size != 1 or shuffle or sampler is not None or drop_last:
            raise ValueError('batch_sampler option is mutually exclusive '
                             'with batch_size, shuffle, sampler, and '
                             'drop_last')
        batch_size = None
        drop_last = False
    # If batch_size is None and drop_last is set, raise an error
    elif batch_size is None:
        # no auto_collation
        if drop_last:
            raise ValueError('batch_size=None option disables auto-batching '
                             'and is mutually exclusive with drop_last')

    # If no sampler was passed in
    if sampler is None:  # give default samplers
        if self._dataset_kind == _DatasetKind.Iterable:
            # See NOTE [ Custom Samplers and IterableDataset ]
            sampler = _InfiniteConstantSampler()
        else:
            # map-style (the common case): if shuffle is set, the built-in RandomSampler
            # class is used to shuffle the Dataset; its implementation is shown in a later section.
            if shuffle:
                sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
            # If shuffle is not True, SequentialSampler is used, i.e. sampling in the original order
            else:
                sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

    # If batch_size is not None and no batch_sampler was given,
    # a default BatchSampler is constructed for you.
    # The BatchSampler source is covered in a later section.
    if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    self.batch_size = batch_size
    self.drop_last = drop_last
    self.sampler = sampler
    self.batch_sampler = batch_sampler
    self.generator = generator

    # If collate_fn is None, the default is chosen based on auto-collation
    if collate_fn is None:
        # _auto_collation is derived from whether batch_sampler is None: if batch_sampler is not None,
        # _auto_collation is True and _utils.collate.default_collate is used; if batch_sampler is None,
        # _utils.collate.default_convert is used instead.
        # default_collate takes the whole batch (a list of samples) as input and collates it into
        # batched tensors; a custom collate_fn must likewise take the batch as input and return
        # the processed result.
        if self._auto_collation:
            collate_fn = _utils.collate.default_collate
        else:
            collate_fn = _utils.collate.default_convert

    self.collate_fn = collate_fn
    self.persistent_workers = persistent_workers

    self.__initialized = True
    self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

    self._iterator = None

    self.check_worker_number_rationality()

    torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]
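To see how the branching above plays out, here is a small sketch built on a toy TensorDataset (the data itself is arbitrary). It checks which sampler and batch_sampler the constructor picked and triggers the sampler/shuffle conflict described in the comments:

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

# shuffle=True -> sampler defaults to RandomSampler, wrapped in a BatchSampler
train_loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
print(type(train_loader.sampler).__name__)        # RandomSampler
print(type(train_loader.batch_sampler).__name__)  # BatchSampler

# shuffle=False (default) -> SequentialSampler
val_loader = DataLoader(dataset, batch_size=4)
print(type(val_loader.sampler).__name__)          # SequentialSampler

# A custom sampler together with shuffle=True raises ValueError,
# because the two options are mutually exclusive.
try:
    DataLoader(dataset, sampler=SequentialSampler(dataset), shuffle=True)
except ValueError as e:
    print(e)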
After construction, the _get_iterator method decides whether samples are loaded in the main process or by worker processes:

def _get_iterator(self) -> '_BaseDataLoaderIter':
    # If num_workers is 0, samples are loaded in a single process
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        # If num_workers > 0, samples are read by multiple worker processes
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
For ordinary iteration this is wired up in the __iter__ method, which is what makes the DataLoader an iterable object.
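In practice that just means you drive the DataLoader with iter()/next() or a plain for loop. A minimal sketch on a toy TensorDataset (the data is arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8))
# num_workers=0 -> __iter__ returns a _SingleProcessDataLoaderIter under the hood
loader = DataLoader(dataset, batch_size=3, num_workers=0)

it = iter(loader)            # calls DataLoader.__iter__, which calls _get_iterator
first_batch = next(it)       # one mini-batch: a list [features, labels]
print(first_batch[0].shape)  # torch.Size([3, 1])

# Or equivalently, the usual training-style loop:
for features, labels in loader:
    print(features.shape, labels.shape)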
Next is RandomSampler; focus on the comments added in the code.
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    # Start with the __iter__ method
    def __iter__(self) -> Iterator[int]:
        # Size of the dataset
        n = len(self.data_source)
        # If no generator was passed in, pick a random seed and build a generator from it
        if self.generator is None:
            # Seed for the random number generator
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
        if self.replacement:
            # With replacement: draw random indices in [0, n), 32 at a time
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # Without replacement: yield random permutations of 0 .. n-1, where n is the dataset length
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
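A quick sketch of using RandomSampler directly (the toy index list is arbitrary) shows that it simply yields dataset indices in shuffled order:

import torch
from torch.utils.data import RandomSampler

data = list(range(10))  # anything with __len__ works as data_source
sampler = RandomSampler(data, generator=torch.Generator().manual_seed(0))
print(list(sampler))    # a permutation of 0..9, e.g. [4, 1, 7, ...]
print(len(sampler))     # 10, i.e. num_samples defaults to len(data_source)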
SequentialSampler, by contrast, just walks the dataset in its original order:

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    # Iterating it yields the indices in their original order
    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)
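Used directly, it just counts through the indices (toy example, the range is arbitrary):

from torch.utils.data import SequentialSampler

sampler = SequentialSampler(range(5))
print(list(sampler))  # [0, 1, 2, 3, 4]
print(len(sampler))   # 5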
For BatchSampler, again go straight to the __iter__ method.
class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    # Start with the __iter__ method
    def __iter__(self) -> Iterator[List[int]]:
        # Start with an empty list batch
        batch = []
        # Iterate over the sampler to get element indices
        for idx in self.sampler:
            # Append each index to the list
            batch.append(idx)
            # Once the list reaches batch_size, yield it as one batch and reset it to empty
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # If drop_last (whether to drop the final batch that has fewer than batch_size elements)
        # is False, yield the leftover indices as a final, smaller batch
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
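A short sketch of BatchSampler grouping the indices produced by a SequentialSampler (the sizes here are arbitrary):

from torch.utils.data import BatchSampler, SequentialSampler

indices = SequentialSampler(range(10))  # yields 0, 1, ..., 9 in order
print(list(BatchSampler(indices, batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]  -- the last short batch is kept
print(list(BatchSampler(indices, batch_size=3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]       -- the last short batch is dropped

This is exactly the object the DataLoader constructor builds for you when you pass batch_size without a custom batch_sampler.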
The uploader explains things in great detail and I did not record everything, so check the video for some of the finer points.