赞
踩
1.Dataset
pytorch中提供了两种Dataset,一种是Dataset,另一种是IterableDataset。
在构建Dataset子类的时候,一般来说只需要定义__init__、__get_item__和__len__这3个方法,它们的作用分别如下:
__init__:初始化类
__get_item__:提取Dataset中的元素,通常是元组形式,如(input,target)
__len__:在对Dataset取len时,返回Dataset中的元素个数
IterableDataset是一个迭代器,需要重写__iter__方法,通过__iter__方法获得下一条数据。(这个目前没有遇到,待深入研究)
2.DataLoader
DataLoader提供了将数据整合成一个个批次的方法,用于进行模型批量运算。DataLoader中有如下几个需要注意的参数:
batch_size:一个批次数据中的样本数量
shuffle:打扰数据,避免模型陷入局部最优的情况,在定义了sampler之后,这个参数就无法使用了
sampler:采样器,如果有特殊的数据整合需求,可以自定义一个sampler,在sampler中返回每个批次的数据下标列表
pin_memory:将数据传入CUDA的Pinned Memory,方便更快的传入GPU中
collate_fn:进一步处理打包sampler筛选出来的一组组数据
num_workers:采用多进程方式加载,如果CPU能力较强,可以选择这种方法
drop_last:在样本总数不能被批次大小整除的情况下,最后一个批次的样本数量可能会与前面的批次不一致,若模型要求每个批次样本数量一致,可以将drop_last设置为True
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。