当前位置:   article > 正文

pytorch深度学习笔记(一)自定义数据集_深度学习自定义数据集

深度学习自定义数据集

参考链接:https://blog.csdn.net/guyuealian/article/details/88343924


在使用pytorch进行深度学习训练时,很多时候待训练的数据都是自己采集的,对于这一类数据我们需要使用pytorch中的Dataset和DataLoader类来进行封装,产生自定义的训练数据。

1、torch.utils.data.Dataset

datasets这是一个pytorch定义的dataset的源码集合。下面是一个自定义Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()两个方法是必须重写的。getitem()返回训练数据,如图片和label,而__len__()返回数据长度。

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

2、torch.utils.data.DataLoader

DataLoader(object)可用参数:

  1. dataset(Dataset): 传入的数据集
  2. batch_size(int, optional): 每个batch有多少个样本
  3. shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
  4. sampler(Sampler, optional):自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必 须为False
  5. batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
  6. num_workers (int, optional): 这个参数决定了有几个进程来处理dataloading。0意味着所有的数据都会被load进主进程。(默认为0)
  7. collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
  8. pin_memory (bool, optional): 如果设置为True,那么data
    loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
  9. drop_last (bool, optional):如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
  10. timeout(numeric,optional):如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

3、数据格式

通常,如果要自定义训练数据,那么在数据采集时最好以特定的格式去存储数据标签以及数据本身。
假设将数据标签存储在TXT文件中,则txt中每一行代表一个数据的,其形式通常为 “name + label”。示例如下图所示:

0.jpg 0
1.jpg 1
2.jpg 2
3.jpg 3
4.jpg 4
5.jpg 5
6.jpg 6
7.jpg 7
8.jpg 8
9.jpg 9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

当然,也可以是多标签的数据:

0.jpg 0 10
1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18
9.jpg 9 19
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

4、自定义Dataset类

自定义的Dataset类,需要继承torch.utils.data中的Dataset类,内部函数可根据自己的数据来自定义。

class Data_set(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.img_path = images_path
        self.transform = transform
        data = []
        with open(labels_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # use restrip() remove "\n, \r, \t, ' ' "
                content = line.rstrip().split(' ')
                content[1] = eval(content[1])
                data.append([content[0],content[1]])#, content[1]))
                self.data = data

    def __getitem__(self, index):
        img_name, labels = self.data[index]
        image_path = os.path.join(self.img_path, img_name)
        img = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        labels1 = torch.from_numpy(np.array([labels[0] / 1280, labels[1] / 640], dtype=np.float32))
        return img, labels1

    def __len__(self):
        return len(self.data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

其中,index是不断进行迭代的,每一都加1,直到index = len(self.data)

5、使用DataLoader产生批量训练数据

transforms = transforms.Compose (
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])
train_dataset = Data_set(img_path, label_path, transforms)
train_dataloader = DataLoader(train_dataset, shuffle=True,batch_size=batch_size, num_workers=8)
                                  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

通过Data_set类来加载数据集,再通过DataLoader类来产生批量训练数据。其内部各参数的功能,在前文有提到。在这里提一下num_works,训练过程中,不同大小的num_works对训练速度影响较大,其主要是影响CPU向GPU传递数据的速度。而当数据传输较慢时,GPU很快能处理好接收到的数据,但CPU传输的较慢,使得总的训练时间加长(CPU传输数据时间+GPU处理数据时间),这也就出现训练过程中GPU的利用率出现较大波动。
当然,GPU利用率出现波动,也并不完全归结于num_works的大小,Dataset类的效率,也会对数据传输的速率产生影响。 更详细的介绍可以参考这篇博客.

6、整体代码为

(代码不可直接运行,仅仅作为框架参考)

class Data_set(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.img_path = images_path
        self.transform = transform
        data = []
        with open(labels_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # use restrip() remove "\n, \r, \t, ' ' "
                content = line.rstrip().split(' ')
                content[1] = eval(content[1])
                data.append([content[0],content[1]])#, content[1]))
                self.data = data

    def __getitem__(self, index):
        img_name, labels = self.data[index]
        image_path = os.path.join(self.img_path, img_name)
        img = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        labels1 = torch.from_numpy(np.array([labels[0] / 1280, labels[1] / 640], dtype=np.float32))
        return img, labels1

    def __len__(self):
        return len(self.data)

if __name__ == '__main__':
    transforms = transforms.Compose (
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])

    train_dataset = Data_set(img_path, label_path, transforms)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                                  num_workers=8)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/267956
推荐阅读
相关标签
  

闽ICP备14008679号