当前位置:   article > 正文

Sampler类与4种采样方式_sequentialsampler

sequentialsampler

参考 Pytorch的Sampler详解 - 云+社区 - 腾讯云

参考Sampler类与4种采样方式 - 云+社区 - 腾讯云

由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,pytorch提供了Sampler基类【1】与多个子类实现不同方式的数据采样。子类包含:

  • Sequential Sampler(顺序采样)
  • Random Sampler(随机采样)
  • Subset Random Sampler(子集随机采样)
  • Weighted Random Sampler(加权随机采样)等等。

1、基类Sampler

  1. class Sampler(object):
  2. r"""Base class for all Samplers.
  3. """
  4. def __init__(self, data_source):
  5. pass
  6. def __iter__(self):
  7. raise NotImplementedError

对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。

2、顺序采样Sequential Sampler

  1. class SequentialSampler(Sampler):
  2. r"""Samples elements sequentially, always in the same order.
  3. Arguments:
  4. data_source (Dataset): dataset to sample from
  5. """
  6. def __init__(self, data_source):
  7. self.data_source = data_source
  8. def __iter__(self):
  9. return iter(range(len(self.data_source)))
  10. def __len__(self):
  11. return len(self.data_source)

顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;__iter__(

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/700906
推荐阅读
相关标签
  

闽ICP备14008679号