赞
踩
最近在使用pytorch复现PointNet分割网络的过程中,在读入数据时遇到了一些问题,需要重写DataLoader中的sampler和collate_fn
sampler的作用是按照指定的顺序向batch里面读入数据,自定义的sampler可以根据我们的需要返回索引,DataLoader会根据我们返回的索引值提取数据,生成batch
注意:
重写sampler需要重写__len__()和__iter__()方法,其中__len__()返回你读入数据的总长度,iter()返回一个迭代器
例如,我们需要sampler根据样本点数返回索引
class sampler(data.Sampler): """ 由于每个batch的点数可能不一致 例如 len(b[0])=10220, len(b[1])=23300, len(b[2])=24000 , ... 该sampler是为了将每个batch内的点数统一 首先将batch里的样本按照点数从小到大排列 返回排序之后的索引值 """ def __init__(self, data_source): super(sampler, self).__init__(data_source) self.x = data_source # y = data_source[1] self.lst = [] for i in range(len(self.x)): self.lst.append(self.x[i].shape[0]) self.idx = np.argsort(self.lst) # 排序之后的索引 def __iter__(self): return iter(self.idx) # 这里的idx最后会返回给DataSets中的__getitem__方法 def __len__(self): return len(self.x[0]) # 这里的__len__需要返回总长度
在data.DataLoader中找到了下面注释:
sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be anyIterable
with__len__
implemented. If specified, :attr:shuffle
must not be specified.
意思是自己重写了sampler之后,shuffle关键字不能指定。
很重要的一点:
最后在传参的时候,传入的是sampler的示例
例如:
p_loader = data.DataLoader(p, batch_size, drop_last=True, sampler=sampler(sourcedata)) # p是DateSet的子类
collate_fn方法的作用是对于还未被连结batch进行操作,因为是没有被连结所以这里我说的batch是list类型
pytorch规定每一个batch中样本的点数必须相同,所以重写collate_fn方法,将每个batch中样本下采样到相同的数目
这里的函数的下采很简单,就是单纯的取得batch中的最小样本点数,将其他样本中的点shuffle之后取前最小个点数
def collate_fn(batch: list): """ DataLoader中的最后一步 对即将输出的batch进行操作 这里将每一个batch的样本点数降到最少点数即min_num_pts :param batch: B*N*3 不同B中的N并不相同 :return: """ batch_size = len(batch) ret_cls = [] for elem in batch: ret_cls.append(elem[2]) ret_cls = torch.tensor(ret_cls) num_pts_lst = [] for elem in batch: X, y = elem[0], elem[1] num_pts_lst.append(X.shape[0]) sorted_lst = np.argsort(num_pts_lst) min_mun_pts = len(batch[sorted_lst[0]][0]) temp_points = [] temp_target = [] for i in range(batch_size): X, y = batch[i][0], batch[i][1] torch.manual_seed(2021) X = X[torch.randperm(min_mun_pts)] torch.manual_seed(2021) y = y[torch.randperm(min_mun_pts)] temp_points.append(X) temp_target.append(y) ret_points = temp_points[0].unsqueeze(0) ret_target = temp_target[0].unsqueeze(0) for i in range(1, batch_size): cat_X = temp_points[i].unsqueeze(0) cat_y = temp_target[i].unsqueeze(0) ret_points = torch.cat([ret_points, cat_X], dim=0) ret_target = torch.cat([ret_target, cat_y], dim=0) return ret_points, ret_target, ret_cls
最后传入参数时,传入的是方法名
p_loader = data.DataLoader(p, batch_size, drop_last=True, collate_fn=collate_fn)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。