赞
踩
参考链接:https://blog.csdn.net/guyuealian/article/details/88343924
datasets这是一个pytorch定义的dataset的源码集合。下面是一个自定义Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()两个方法是必须重写的。getitem()返回训练数据,如图片和label,而__len__()返回数据长度。
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
DataLoader(object)可用参数:
通常,如果要自定义训练数据,那么在数据采集时最好以特定的格式去存储数据标签以及数据本身。
假设将数据标签存储在TXT文件中,则txt中每一行代表一个数据的,其形式通常为 “name + label”。示例如下图所示:
0.jpg 0
1.jpg 1
2.jpg 2
3.jpg 3
4.jpg 4
5.jpg 5
6.jpg 6
7.jpg 7
8.jpg 8
9.jpg 9
当然,也可以是多标签的数据:
0.jpg 0 10
1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18
9.jpg 9 19
自定义的Dataset类,需要继承torch.utils.data中的Dataset类,内部函数可根据自己的数据来自定义。
class Data_set(Dataset): def __init__(self, images_path, labels_path, transform=None): self.img_path = images_path self.transform = transform data = [] with open(labels_path, 'r') as f: lines = f.readlines() for line in lines: # use restrip() remove "\n, \r, \t, ' ' " content = line.rstrip().split(' ') content[1] = eval(content[1]) data.append([content[0],content[1]])#, content[1])) self.data = data def __getitem__(self, index): img_name, labels = self.data[index] image_path = os.path.join(self.img_path, img_name) img = Image.open(image_path).convert('RGB') if self.transform is not None: img = self.transform(img) labels1 = torch.from_numpy(np.array([labels[0] / 1280, labels[1] / 640], dtype=np.float32)) return img, labels1 def __len__(self): return len(self.data)
其中,index是不断进行迭代的,每一都加1,直到index = len(self.data)
transforms = transforms.Compose (
[transforms.Resize((224, 224)),
transforms.ToTensor()])
train_dataset = Data_set(img_path, label_path, transforms)
train_dataloader = DataLoader(train_dataset, shuffle=True,batch_size=batch_size, num_workers=8)
通过Data_set类来加载数据集,再通过DataLoader类来产生批量训练数据。其内部各参数的功能,在前文有提到。在这里提一下num_works,训练过程中,不同大小的num_works对训练速度影响较大,其主要是影响CPU向GPU传递数据的速度。而当数据传输较慢时,GPU很快能处理好接收到的数据,但CPU传输的较慢,使得总的训练时间加长(CPU传输数据时间+GPU处理数据时间),这也就出现训练过程中GPU的利用率出现较大波动。
当然,GPU利用率出现波动,也并不完全归结于num_works的大小,Dataset类的效率,也会对数据传输的速率产生影响。 更详细的介绍可以参考这篇博客.
(代码不可直接运行,仅仅作为框架参考)
class Data_set(Dataset): def __init__(self, images_path, labels_path, transform=None): self.img_path = images_path self.transform = transform data = [] with open(labels_path, 'r') as f: lines = f.readlines() for line in lines: # use restrip() remove "\n, \r, \t, ' ' " content = line.rstrip().split(' ') content[1] = eval(content[1]) data.append([content[0],content[1]])#, content[1])) self.data = data def __getitem__(self, index): img_name, labels = self.data[index] image_path = os.path.join(self.img_path, img_name) img = Image.open(image_path).convert('RGB') if self.transform is not None: img = self.transform(img) labels1 = torch.from_numpy(np.array([labels[0] / 1280, labels[1] / 640], dtype=np.float32)) return img, labels1 def __len__(self): return len(self.data) if __name__ == '__main__': transforms = transforms.Compose ( [transforms.Resize((224, 224)), transforms.ToTensor()]) train_dataset = Data_set(img_path, label_path, transforms) train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=8)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。