
DataLoader and Dataset: "prefetch_factor option could only be specified in multiprocessing. let num_workers > 0 to enable multiprocessing."

I. Overview

This post walks through the source code of torch.utils.data.DataLoader and torch.utils.data.Dataset, and explains where the error "prefetch_factor option could only be specified in multiprocessing. let num_workers > 0 to enable multiprocessing." comes from.

II. Detailed walkthrough

DataLoader source code

class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of sample loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers samples prefetched across all workers. (default: ``2``)
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

    See `Dataset Types`_ for more details on these two types of datasets and how
    :class:`~torch.utils.data.IterableDataset` interacts with
    `Multi-process data loading`_.
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data lost (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `Iterabledataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None
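The check near the top of __init__ is where the error in this post's title comes from: with the defaults shown above, prefetch_factor only makes sense when worker subprocesses exist, so passing a non-default value together with num_workers=0 raises the ValueError. A minimal sketch that reproduces the error and shows the two obvious ways around it (the TensorDataset here is just an assumed stand-in for any dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10).float())   # any dataset works here

# Reproduces: "prefetch_factor option could only be specified in multiprocessing. ..."
try:
    DataLoader(data, batch_size=2, num_workers=0, prefetch_factor=4)
except ValueError as e:
    print(e)

# Fix 1: enable multiprocessing so that prefetch_factor can be tuned.
loader = DataLoader(data, batch_size=2, num_workers=2, prefetch_factor=4)

# Fix 2: stay single-process and simply do not pass prefetch_factor at all.
loader = DataLoader(data, batch_size=2, num_workers=0)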

The main arguments passed to the DataLoader are shown below:

DataLoader(dataset,
           batch_size=1,                  # number of samples per batch
           shuffle=False,                 # reshuffle the data at every epoch
           sampler=None,
           batch_sampler=None,
           num_workers=0,
           collate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)
# Purpose: build an iterable data loader
# dataset: a Dataset object that decides where the data comes from and how it is read
# batch_size: batch size
# num_workers: number of subprocesses used to read the data (0 = read in the main process)
# shuffle: whether to reshuffle the samples at every epoch
# drop_last: whether to drop the last incomplete batch when the sample count is not divisible by batch_size
# Epoch: one pass in which every training sample has been fed to the model once
# Iteration: one batch of samples fed to the model
# batch_size: together with the dataset size, it determines how many Iterations make up one Epoch
# Example: 81 samples, batch_size = 8
#   1 Epoch = 10 Iterations if drop_last=True
#   1 Epoch = 11 Iterations if drop_last=False
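A quick way to verify the Epoch/Iteration arithmetic above is to count the batches directly; a small sketch, where the 81-sample TensorDataset is only an illustrative assumption:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(81).float())        # 81 samples

print(len(DataLoader(dataset, batch_size=8, drop_last=True)))    # 10 iterations per epoch
print(len(DataLoader(dataset, batch_size=8, drop_last=False)))   # 11 iterations; the last batch holds 1 sample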
Dataset source code

torch.utils.data.Dataset is the abstract Dataset class. Every custom dataset must inherit from it and override __getitem__, which receives an index and returns the corresponding sample.

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

# Example: __getitem__ of a typical image dataset
class Dataset(object):
    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # 0~255
        if self.transform is not None:
            img = self.transform(img)
        return img, label

1. Which data to read: the indices produced by the Sampler

2. Where to read the data from: the data_dir held by the Dataset

3. How to read the data: the Dataset's __getitem__ (a minimal end-to-end sketch follows below)
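Putting the three points together, here is a minimal end-to-end sketch of a custom map-style dataset used with a DataLoader. The directory layout (data/train/<class>/<image>), the class name MyImageDataset, and the label convention are assumptions for illustration, not part of the original post.

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class MyImageDataset(Dataset):
    """Hypothetical map-style dataset: one sub-folder per class under data_dir."""

    def __init__(self, data_dir, transform=None):
        self.transform = transform
        # 2. Where to read from: collect (path, label) pairs from data_dir.
        self.data_info = []
        for label, cls in enumerate(sorted(os.listdir(data_dir))):
            cls_dir = os.path.join(data_dir, cls)
            for name in os.listdir(cls_dir):
                self.data_info.append((os.path.join(cls_dir, name), label))

    def __getitem__(self, index):
        # 1. Which data to read: `index` is produced by the DataLoader's Sampler.
        # 3. How to read it: open the image and apply the transform.
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')   # 0~255
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

if __name__ == '__main__':
    tf = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    train_set = MyImageDataset('data/train', transform=tf)
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                              num_workers=2, drop_last=True)
    for imgs, labels in train_loader:
        pass   # one Iteration of the training loop goes here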
