当前位置:   article > 正文

pytorch构建并加载自己的图像数据集_torch 自己建立训练图像

torch 自己建立训练图像

前言:

pytorch虽然提供了torchvision.datasets包,封装了一些常用的数据集供我们很方便地调用,但我们经常需要训练自己的图像数据,构建并加载数据集往往是训练神经网络的第一步,本文将介绍如何构建加载自己的图像数据集,并用于神经网络输入。

一. 自定义图像数据集:

1. 数据集的文件结构:

train为数据集根目录,下一级为每个类别的文件夹,分别包含着若干张图像:

2. torch.utils.data.Dataset:

Dataset是表示数据集的抽象类,当我们自定义数据集时应继承Dataset类,并重写以下方法 :

1. __getitem__: 支持根据给定的key来获取数据样本。

2. __len__: 实现返回数据集的数据数量。

构建的自定义数据集类如下:

  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. from pathlib import Path
  4. import cv2
  5. def get_images_and_labels(dir_path):
  6. '''
  7. 从图像数据集的根目录dir_path下获取所有类别的图像名列表和对应的标签名列表
  8. :param dir_path: 图像数据集的根目录
  9. :return: images_list, labels_list
  10. '''
  11. dir_path = Path(dir_path)
  12. classes = [] # 类别名列表
  13. for category in dir_path.iterdir():
  14. if category.is_dir():
  15. classes.append(category.name)
  16. images_list = [] # 文件名列表
  17. labels_list = [] # 标签列表
  18. for index, name in enumerate(classes):
  19. class_path = dir_path / name
  20. if not class_path.is_dir():
  21. continue
  22. for img_path in class_path.glob('*.jpg'):
  23. images_list.append(str(img_path))
  24. labels_list.append(int(index))
  25. return images_list, labels_list
  26. class MyDataset(Dataset):
  27. def __init__(self, dir_path, transform=None):
  28. self.dir_path = dir_path # 数据集根目录
  29. self.transform = transform
  30. self.images, self.labels = get_images_and_labels(self.dir_path)
  31. def __len__(self):
  32. # 返回数据集的数据数量
  33. return len(self.images)
  34. def __getitem__(self, index):
  35. img_path = self.images[index]
  36. label = self.labels[index]
  37. img = cv2.imread(img_path)
  38. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  39. sample = {'image': img, 'label': label}
  40. if self.transform:
  41. sample['image'] = self.transform(sample['image'])
  42. return sample

二. 加载数据集:

1. torch.utils.data.DataLoader:

DataLoader是一个数据集加载器类,提供了很多方便的数据集操作,比如shuffle,batch,drop_last等,详细用法可参考文档。

  1. if __name__ == '__main__':
  2. train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images")
  3. dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
  4. for index, batch_data in enumerate(dataloader):
  5. print(index, batch_data['image'].shape, batch_data['label'].shape)

运行的结果如下所示:

  1. 0 torch.Size([64, 224, 224, 3]) torch.Size([64])
  2. 1 torch.Size([64, 224, 224, 3]) torch.Size([64])
  3. 2 torch.Size([64, 224, 224, 3]) torch.Size([64])
  4. 3 torch.Size([64, 224, 224, 3]) torch.Size([64])
  5. 4 torch.Size([64, 224, 224, 3]) torch.Size([64])
  6. 5 torch.Size([64, 224, 224, 3]) torch.Size([64])
  7. 6 torch.Size([64, 224, 224, 3]) torch.Size([64])
  8. 7 torch.Size([64, 224, 224, 3]) torch.Size([64])
  9. 8 torch.Size([64, 224, 224, 3]) torch.Size([64])
  10. 9 torch.Size([64, 224, 224, 3]) torch.Size([64])
  11. 10 torch.Size([36, 224, 224, 3]) torch.Size([36])
  12. Process finished with exit code 0

至此,就完成了图像数据集的构建与加载。

三. 对图像进行数据增强:

在自定义的数据集类中,我们设置了一个参数transform但没有用,接下来就来介绍如何定义transform用于图像增强:

1. torchvision.transforms:

torchvision.transforms中实现了许多常见的数据增强操作,比如Scale, Crop, Resize, Normalize, ColorJitter等等等等,可以浏览transforms.py查看所有操作。这里直接定义一个返回transforms.Compose的方法(有关transforms.Compose可以参考https://blog.csdn.net/PanYHHH/article/details/106045153):

  1. def get_transform_for_train():
  2. transform_list = []
  3. transform_list.append(transforms.ToPILImage())
  4. transform_list.append(transforms.RandomHorizontalFlip(p=0.3))
  5. transform_list.append(transforms.ColorJitter(0.1, 0.1, 0.1, 0.1))
  6. transform_list.append(transforms.ToTensor())
  7. transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
  8. return transforms.Compose(transform_list)

然后在实例化自定义数据集类的时候,就可以将transform作为参数传入:

  1. if __name__ == '__main__':
  2. train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images", get_transform_for_train())
  3. dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  4. for index, batch_data in enumerate(dataloader):
  5. print(index, batch_data['image'].shape, batch_data['label'].shape)

运行结果为:

  1. 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
  2. 1 torch.Size([64, 3, 224, 224]) torch.Size([64])
  3. 2 torch.Size([64, 3, 224, 224]) torch.Size([64])
  4. 3 torch.Size([64, 3, 224, 224]) torch.Size([64])
  5. 4 torch.Size([64, 3, 224, 224]) torch.Size([64])
  6. 5 torch.Size([64, 3, 224, 224]) torch.Size([64])
  7. 6 torch.Size([64, 3, 224, 224]) torch.Size([64])
  8. 7 torch.Size([64, 3, 224, 224]) torch.Size([64])
  9. 8 torch.Size([64, 3, 224, 224]) torch.Size([64])
  10. 9 torch.Size([64, 3, 224, 224]) torch.Size([64])
  11. 10 torch.Size([36, 3, 224, 224]) torch.Size([36])
  12. Process finished with exit code 0

后记:

到这里就结束了,欢迎讨论指正。


5月17号更新:

今天学习了torchvision.datasets.ImageFolder,利用这个类可以很方便定义一个通用的数据加载器,它所要求的图像排列结构如下,和文章一开始说的结构是一样的:

  1. '''
  2. root/dog/xxx.png
  3. root/dog/xxy.png
  4. root/dog/xxz.png
  5. root/cat/123.png
  6. root/cat/nsdf3.png
  7. root/cat/asd932_.png
  8. '''

然后只需定义号一个transforms,就可以实现数据集的加载:

  1. if __name__ == '__main__':
  2. train_dataset = ImageFolder(root=r'C:\Users\admin\Desktop\set100-10\annotated_camera_images\train', transform=get_transform_for_train())
  3. dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  4. for index, batch_data in enumerate(dataloader):
  5. print(index, batch_data[0].shape, batch_data[1].shape)

输出结果为:

  1. 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
  2. 1 torch.Size([64, 3, 224, 224]) torch.Size([64])
  3. 2 torch.Size([64, 3, 224, 224]) torch.Size([64])
  4. 3 torch.Size([64, 3, 224, 224]) torch.Size([64])
  5. 4 torch.Size([64, 3, 224, 224]) torch.Size([64])
  6. 5 torch.Size([64, 3, 224, 224]) torch.Size([64])
  7. 6 torch.Size([64, 3, 224, 224]) torch.Size([64])
  8. 7 torch.Size([64, 3, 224, 224]) torch.Size([64])
  9. 8 torch.Size([64, 3, 224, 224]) torch.Size([64])
  10. 9 torch.Size([64, 3, 224, 224]) torch.Size([64])
  11. 10 torch.Size([36, 3, 224, 224]) torch.Size([36])
  12. Process finished with exit code 0

确实非常方便。。一般情况下用这个就够了

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/262418
推荐阅读
相关标签
  

闽ICP备14008679号