当前位置:   article > 正文

pytorch——详解DataLoader中的sampler和collate_fn_dataloader怎么重写sampler

dataloader怎么重写sampler

最近在使用pytorch复现PointNet分割网络的过程中,在读入数据时遇到了一些问题,需要重写DataLoader中的sampler和collate_fn

Sampler

sampler的作用是按照指定的顺序向batch里面读入数据,自定义的sampler可以根据我们的需要返回索引,DataLoader会根据我们返回的索引值提取数据,生成batch

注意:
重写sampler需要重写__len__()和__iter__()方法,其中__len__()返回你读入数据的总长度,iter()返回一个迭代器

例如,我们需要sampler根据样本点数返回索引

class sampler(data.Sampler):
    """
    由于每个batch的点数可能不一致
    例如 len(b[0])=10220, len(b[1])=23300, len(b[2])=24000 , ...
    该sampler是为了将每个batch内的点数统一
    首先将batch里的样本按照点数从小到大排列
    返回排序之后的索引值
    """

    def __init__(self, data_source):
        super(sampler, self).__init__(data_source)
        self.x = data_source
        # y = data_source[1]
        self.lst = []
        for i in range(len(self.x)):
            self.lst.append(self.x[i].shape[0])
        self.idx = np.argsort(self.lst)  # 排序之后的索引

    def __iter__(self):
        return iter(self.idx)  # 这里的idx最后会返回给DataSets中的__getitem__方法

    def __len__(self):
        return len(self.x[0])  # 这里的__len__需要返回总长度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

在data.DataLoader中找到了下面注释:

sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any Iterable with __len__
implemented. If specified, :attr:shuffle must not be specified.

意思是自己重写了sampler之后,shuffle关键字不能指定。

很重要的一点:
最后在传参的时候,传入的是sampler的示例
例如:

p_loader = data.DataLoader(p, batch_size, drop_last=True, sampler=sampler(sourcedata))  # p是DateSet的子类
  • 1

collate_fn

collate_fn方法的作用是对于还未被连结batch进行操作,因为是没有被连结所以这里我说的batch是list类型
在这里插入图片描述

pytorch规定每一个batch中样本的点数必须相同,所以重写collate_fn方法,将每个batch中样本下采样到相同的数目
这里的函数的下采很简单,就是单纯的取得batch中的最小样本点数,将其他样本中的点shuffle之后取前最小个点数

def collate_fn(batch: list):
    """
    DataLoader中的最后一步
    对即将输出的batch进行操作
    这里将每一个batch的样本点数降到最少点数即min_num_pts
    :param batch: B*N*3 不同B中的N并不相同
    :return:
    """
    batch_size = len(batch)
    ret_cls = []
    for elem in batch:
        ret_cls.append(elem[2])
    ret_cls = torch.tensor(ret_cls)

    num_pts_lst = []
    for elem in batch:
        X, y = elem[0], elem[1]
        num_pts_lst.append(X.shape[0])
    sorted_lst = np.argsort(num_pts_lst)

    min_mun_pts = len(batch[sorted_lst[0]][0])

    temp_points = []
    temp_target = []
    for i in range(batch_size):
        X, y = batch[i][0], batch[i][1]
        torch.manual_seed(2021)
        X = X[torch.randperm(min_mun_pts)]
        torch.manual_seed(2021)
        y = y[torch.randperm(min_mun_pts)]
        temp_points.append(X)
        temp_target.append(y)
    ret_points = temp_points[0].unsqueeze(0)
    ret_target = temp_target[0].unsqueeze(0)

    for i in range(1, batch_size):
        cat_X = temp_points[i].unsqueeze(0)
        cat_y = temp_target[i].unsqueeze(0)
        ret_points = torch.cat([ret_points, cat_X], dim=0)
        ret_target = torch.cat([ret_target, cat_y], dim=0)
    return ret_points, ret_target, ret_cls

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

最后传入参数时,传入的是方法名

p_loader = data.DataLoader(p, batch_size, drop_last=True, collate_fn=collate_fn)
  • 1

参考:
知乎大佬的sampler源码解读,很有用

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/533987
推荐阅读
相关标签
  

闽ICP备14008679号