DataLoader Source Code
```python
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of sample loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

    See `Dataset Types`_ for more details on these two types of datasets and how
    :class:`~torch.utils.data.IterableDataset` interacts with
    `Multi-process data loading`_.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None
```
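Most of `__init__` is argument validation and defaulting. For instance, the mutual-exclusion check between `sampler` and `shuffle` quoted above can be triggered directly; a small sketch with toy data:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

ds = TensorDataset(torch.arange(10))

# A custom sampler and shuffle=True are mutually exclusive (see the check above)
try:
    DataLoader(ds, sampler=RandomSampler(ds), shuffle=True)
except ValueError as e:
    print(e)  # sampler option is mutually exclusive with shuffle
```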

The main arguments accepted by the constructor are as follows:
```python
DataLoader(dataset,
           batch_size=1,               # number of samples per batch
           shuffle=False,              # reshuffle the data at every epoch
           sampler=None,
           batch_sampler=None,
           num_workers=0,
           collate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

# Purpose: build an iterable data loader.
#
# dataset:     a Dataset instance; decides where the data is read from and how
# batch_size:  batch size
# num_workers: number of subprocesses used to read the data (0 = main process)
# shuffle:     whether to reshuffle the data at every epoch
# drop_last:   whether to drop the last batch when the number of samples
#              is not divisible by batch_size

# Epoch:     one complete pass of all training samples through the model
# Iteration: one batch of samples passed through the model
# Batchsize: the batch size; determines how many Iterations make up one Epoch

# e.g. 81 samples, batch_size=8:
#   1 Epoch = 10 Iterations  (drop_last=True)
#   1 Epoch = 11 Iterations  (drop_last=False)
```
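The Epoch/Iteration arithmetic above can be checked directly, since `len(dataloader)` counts batches. A minimal sketch with synthetic data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(81))  # 81 samples

# len(dataloader) is the number of iterations per epoch
print(len(DataLoader(ds, batch_size=8, drop_last=True)))   # 10
print(len(DataLoader(ds, batch_size=8, drop_last=False)))  # 11 (last batch holds 1 sample)
```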

Dataset

torch.utils.data.Dataset

Purpose: Dataset is an abstract class. Every custom dataset must subclass it and override:

__getitem__: receives an index and returns the corresponding sample
```python
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
```
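Because of the `__add__` hook above, two datasets can be concatenated with the `+` operator, which returns a `ConcatDataset`; a quick sketch with toy tensors:

```python
import torch
from torch.utils.data import TensorDataset

a = TensorDataset(torch.arange(3))     # samples 0, 1, 2
b = TensorDataset(torch.arange(3, 5))  # samples 3, 4

combined = a + b        # equivalent to ConcatDataset([a, b])
print(len(combined))    # 5
print(combined[3])      # (tensor(3),) -- index 3 falls into b
```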
An example of a custom image dataset (a runnable sketch; `data_info` is assumed to be a list of `(image_path, label)` pairs built elsewhere, e.g. by scanning a folder):

```python
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, data_info, transform=None):
        self.data_info = data_info  # list of (image_path, label) pairs
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # pixel values 0~255

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.data_info)
```
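Plugging it into a DataLoader could then look like this (the file paths are placeholders, and the transform assumes torchvision is installed):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# hypothetical (path, label) pairs
data_info = [('img/cat_0.jpg', 0), ('img/dog_0.jpg', 1)]
train_set = ImageDataset(data_info, transform=transform)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)

for imgs, labels in train_loader:
    print(imgs.shape, labels)  # torch.Size([2, 3, 224, 224]) tensor([...])
```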

1. Which data to read - the indices produced by the Sampler
2. Where to read it from - the data_dir held by the Dataset
3. How to read it - the Dataset's __getitem__
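These three steps are exactly what DataLoader performs internally. A hand-rolled sketch of one epoch with toy data (`default_collate` is imported from the private `_utils` module, matching the source quoted above):

```python
import torch
from torch.utils.data import TensorDataset, SequentialSampler, BatchSampler
from torch.utils.data._utils.collate import default_collate  # internal default collate_fn

ds = TensorDataset(torch.arange(10).float())
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)

for indices in batch_sampler:           # 1. which data: indices from the sampler
    samples = [ds[i] for i in indices]  # 2./3. where and how: Dataset.__getitem__
    batch = default_collate(samples)    # merge the samples into a mini-batch
    print(indices, batch[0])
```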