赞
踩
在之前沐神的Cifar-10分类 课程学习中,沐神是用的将每一类创建一个文件夹去完成图片的导入。此外我们还可以通过重写DataSet类来完成!
通过查看官方文档我们可知。
需要去重写__getitem__这个方法,去以一种特定的方法拿到一个数据。并且选择性的重写__len__这个方法,去返回整个数据集的大小。
这个数据集是沐神课程上讲过的cifar-10数据集。
train和test文件夹分别为要进行训练和测试的图片。而训练数据的标签以csv文件存在trainLabels.csv文件中。
def read_csv_labels(fname):
with open(fname,'r') as f:
lines = f.readlines()[1:]
tokens = [l.rstrip().split(',') for l in lines]
return dict(((name,label) for name,label in tokens))
这里通过一个read_csv_labels的方法 将图片名字和标签以一个字典的方式返回
class MyDateset(Dataset): def __init__(self,root_dir,state,label_dict=None): self.root_dir = root_dir self.state = state if label_dict is not None: self.label_dict = label_dict self.img_path = os.listdir(os.path.join(root_dir,state)) # os.listdir 将当前文件夹下的图片名称按列表返回 def __getitem__(self, idx): img = Image.open(os.path.join(self.root_dir,self.state,self.img_path[idx])) if self.state == 'train': img_num =self.img_path[idx].split('.')[0] # 这个取出来是数字.jpg 所以需要将.jpg舍去 label = self.label_dict[img_num] return img,label else: return img def __len__(self): return len(self.img_path)
state参数表示此时是训练数据集还是测试数据集。
root_dir = "D:\\PytorchLearn\\cifar-10"
label_dict = read_csv_labels(os.path.join(root_dir,"trainLabels.csv"))
train_dataset = MyDateset(root_dir,'train',label_dict)
test_dataset = MyDateset(root_dir,'test')
train_iter = torch.utils.data.DataLoader(train_dataset,batch_size=8,shuffle=True)
以上就是重写DataSet的方法,有不足之处还望各位指出。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。