赞
踩
参考 Pytorch的Sampler详解 - 云+社区 - 腾讯云
参考Sampler类与4种采样方式 - 云+社区 - 腾讯云
由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,pytorch提供了Sampler基类【1】与多个子类实现不同方式的数据采样。子类包含:
- class Sampler(object):
- r"""Base class for all Samplers.
- """
- def __init__(self, data_source):
- pass
- def __iter__(self):
- raise NotImplementedError
对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。
- class SequentialSampler(Sampler):
- r"""Samples elements sequentially, always in the same order.
- Arguments:
- data_source (Dataset): dataset to sample from
- """
- def __init__(self, data_source):
- self.data_source = data_source
- def __iter__(self):
- return iter(range(len(self.data_source)))
- def __len__(self):
- return len(self.data_source)
顺序采样类并没有定义过多的方法,其中初始化方法仅仅需要一个Dataset类对象作为参数。对于__len__()只负责返回数据源包含的数据个数;__iter__(
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。