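Although PyTorch ships the torchvision.datasets package, which wraps a number of common datasets for convenient use, we often need to train on our own image data. Building and loading a dataset is usually the first step in training a neural network, so this article shows how to build and load your own image dataset and feed it to a network.
1. Dataset file structure:
train is the dataset root directory. Its immediate subdirectories are the class folders, one per class, and each holds that class's images (for example train/<class name>/xxx.jpg).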
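2. torch.utils.data.Dataset:
Dataset is the abstract class that represents a dataset. A custom dataset should inherit from Dataset and override the following methods:
1. __getitem__: returns the data sample for a given key (index).
2. __len__: returns the number of samples in the dataset.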
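To make the two methods concrete, here is a minimal sketch of a Dataset subclass. SquaresDataset is a hypothetical toy class invented for this illustration; it is not part of PyTorch or of the dataset built in this article.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Called by len(dataset); the default samplers use it to know how many samples exist
        return self.n

    def __getitem__(self, index):
        # Called by dataset[index]; returns one sample
        if not 0 <= index < self.n:
            raise IndexError(index)
        x = torch.tensor(float(index))
        return {'x': x, 'y': x * x}


toy = SquaresDataset(5)
print(len(toy))      # number of samples
print(toy[3]['y'])   # the 'y' field of the sample at index 3
```

Returning a dict from __getitem__ is a design choice, not a requirement; a tuple such as (image, label) works just as well.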
The custom dataset class is built as follows:
- from pathlib import Path
-
- import cv2
- from torch.utils.data import Dataset, DataLoader
-
-
- def get_images_and_labels(dir_path):
-     '''
-     Collect the image paths of every class under the dataset root dir_path,
-     together with the matching integer labels.
-     :param dir_path: root directory of the image dataset
-     :return: images_list, labels_list
-     '''
-     dir_path = Path(dir_path)
-     # Class names; sorted so the name-to-index mapping is the same on every run
-     # (iterdir() yields entries in arbitrary order)
-     classes = sorted(p.name for p in dir_path.iterdir() if p.is_dir())
-
-     images_list = []  # image file paths
-     labels_list = []  # integer labels
-
-     for index, name in enumerate(classes):
-         class_path = dir_path / name
-         # Only .jpg files are collected
-         for img_path in sorted(class_path.glob('*.jpg')):
-             images_list.append(str(img_path))
-             labels_list.append(index)
-     return images_list, labels_list
-
-
- class MyDataset(Dataset):
-     def __init__(self, dir_path, transform=None):
-         self.dir_path = dir_path  # dataset root directory
-         self.transform = transform
-         self.images, self.labels = get_images_and_labels(self.dir_path)
-
-     def __len__(self):
-         # Number of samples in the dataset
-         return len(self.images)
-
-     def __getitem__(self, index):
-         img_path = self.images[index]
-         label = self.labels[index]
-         img = cv2.imread(img_path)  # BGR, H x W x C; None if the file cannot be read
-         if img is None:
-             raise FileNotFoundError(f'Cannot read image: {img_path}')
-         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-         sample = {'image': img, 'label': label}
-         if self.transform:
-             sample['image'] = self.transform(sample['image'])
-         return sample
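Each image's label is the position of its class folder in the class list. Path.iterdir() returns entries in arbitrary order, so sorting the folder names is one way to keep that mapping stable across runs and machines. The sketch below illustrates this on a temporary directory; class_to_index is a hypothetical helper, and the class and file names are made up.

```python
import tempfile
from pathlib import Path


def class_to_index(root):
    """Map each class folder name under root to an integer label (sorted by name)."""
    names = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    return {name: i for i, name in enumerate(names)}


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / 'train'
    # Hypothetical classes and file names, created only for this demonstration
    for cls, files in {'dog': ['a.jpg', 'b.jpg'], 'cat': ['c.jpg']}.items():
        (root / cls).mkdir(parents=True)
        for f in files:
            (root / cls / f).touch()
    (root / 'notes.txt').touch()  # a stray file at the root is not a class

    mapping = class_to_index(root)
    counts = {cls: len(list((root / cls).glob('*.jpg'))) for cls in mapping}

print(mapping)  # {'cat': 0, 'dog': 1}
print(counts)   # {'cat': 1, 'dog': 2}
```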
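3. torch.utils.data.DataLoader:
DataLoader is a dataset loader class that provides many convenient operations on a dataset, such as shuffle, batch, and drop_last; see the documentation for details. Batching stacks the samples into a single tensor, so every image in the dataset must have the same size (224 x 224 here).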
- if __name__ == '__main__':
-     train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images")
-     dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
-     for index, batch_data in enumerate(dataloader):
-         print(index, batch_data['image'].shape, batch_data['label'].shape)
The output is as follows:
- 0 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 1 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 2 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 3 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 4 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 5 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 6 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 7 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 8 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 9 torch.Size([64, 224, 224, 3]) torch.Size([64])
- 10 torch.Size([36, 224, 224, 3]) torch.Size([36])
-
- Process finished with exit code 0
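The last batch in this output holds 36 samples instead of 64. Ten full batches plus 36 gives 676 samples, which is not a multiple of 64, and DataLoader keeps the incomplete final batch unless drop_last=True. The sketch below reproduces the batch sizes with small placeholder tensors; the 676 figure is inferred from the output above, and the tensor contents are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 676 tiny "images" instead of real 224 x 224 ones
data = TensorDataset(torch.zeros(676, 2, 2, 3), torch.zeros(676, dtype=torch.long))

keep_last = DataLoader(data, batch_size=64, shuffle=True)
drop_last = DataLoader(data, batch_size=64, shuffle=True, drop_last=True)

sizes_keep = [images.shape[0] for images, labels in keep_last]
sizes_drop = [images.shape[0] for images, labels in drop_last]

print(len(sizes_keep), sizes_keep[-1])  # 11 36
print(len(sizes_drop), sizes_drop[-1])  # 10 64
```

Dropping the last batch can help when a layer or loss is sensitive to batch size, at the cost of skipping a few samples in each epoch.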
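This completes building and loading the image dataset.
The custom dataset class takes a transform parameter that has not been used so far. The next section shows how to define a transform for image augmentation:
4. torchvision.transforms:
torchvision.transforms implements many common data augmentation and preprocessing operations, such as Resize, cropping (CenterCrop, RandomCrop), Normalize, and ColorJitter; browse transforms.py to see them all. (Scale, found in older releases, has been replaced by Resize.) Here we simply define a function that returns a transforms.Compose (for more on transforms.Compose, see https://blog.csdn.net/PanYHHH/article/details/106045153):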
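- from torchvision import transforms
-
-
- def get_transform_for_train():
-     transform_list = []
-     # RandomHorizontalFlip and ColorJitter do not accept numpy arrays, so convert to a PIL image first
-     transform_list.append(transforms.ToPILImage())
-     transform_list.append(transforms.RandomHorizontalFlip(p=0.3))
-     # Arguments: brightness, contrast, saturation, hue
-     transform_list.append(transforms.ColorJitter(0.1, 0.1, 0.1, 0.1))
-     # H x W x C uint8 in [0, 255] -> C x H x W float in [0, 1]
-     transform_list.append(transforms.ToTensor())
-     # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
-     transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
-     return transforms.Compose(transform_list)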
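The last two steps of this pipeline change both the layout and the value range of the image. ToTensor moves the channel axis to the front and scales uint8 values to [0, 1]; Normalize with mean 0.5 and std 0.5 then maps them to [-1, 1]. This is why batches from a dataset using this transform have shape [N, 3, H, W] rather than [N, H, W, 3]. The sketch below redoes that arithmetic by hand with plain torch operations on a made-up 2 x 2 image; it mirrors what the two transforms compute and does not call torchvision itself.

```python
import torch

# Hypothetical 2 x 2 RGB image in H x W x C layout, uint8 values in [0, 255]
img_hwc = torch.tensor(
    [[[0, 128, 255], [255, 255, 255]],
     [[0, 0, 0], [51, 102, 204]]], dtype=torch.uint8)

# What ToTensor does to such an image: move channels first and scale to [0, 1]
img_chw = img_hwc.permute(2, 0, 1).float() / 255.0

# What Normalize(mean=0.5, std=0.5) does per channel: (x - mean) / std
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
normalized = (img_chw - mean) / std

print(tuple(img_hwc.shape), '->', tuple(normalized.shape))  # (2, 2, 3) -> (3, 2, 2)
print(normalized.min().item(), normalized.max().item())     # -1.0 1.0
```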
Then pass the transform as an argument when instantiating the custom dataset class:
- if __name__ == '__main__':
-     train_dataset = MyDataset(r"C:\Users\admin\Desktop\set100-10\annotated_camera_images", get_transform_for_train())
-     dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
-     for index, batch_data in enumerate(dataloader):
-         print(index, batch_data['image'].shape, batch_data['label'].shape)
The output is:
- 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 1 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 2 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 3 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 4 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 5 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 6 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 7 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 8 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 9 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 10 torch.Size([36, 3, 224, 224]) torch.Size([36])
-
- Process finished with exit code 0
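Afterword:
That is all for this article. Discussion and corrections are welcome.
Update, May 17:
Today I learned about torchvision.datasets.ImageFolder. With this class it is easy to define a general-purpose data loader. It expects images arranged as follows, which is the same structure described at the start of this article: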
- '''
- root/dog/xxx.png
- root/dog/xxy.png
- root/dog/xxz.png
- root/cat/123.png
- root/cat/nsdf3.png
- root/cat/asd932_.png
- '''
Then, once a transform is defined, the dataset can be loaded:
- from torchvision.datasets import ImageFolder
-
-
- if __name__ == '__main__':
-     train_dataset = ImageFolder(root=r'C:\Users\admin\Desktop\set100-10\annotated_camera_images\train', transform=get_transform_for_train())
-     dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
-     for index, batch_data in enumerate(dataloader):
-         print(index, batch_data[0].shape, batch_data[1].shape)
The output is:
- 0 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 1 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 2 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 3 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 4 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 5 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 6 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 7 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 8 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 9 torch.Size([64, 3, 224, 224]) torch.Size([64])
- 10 torch.Size([36, 3, 224, 224]) torch.Size([36])
-
- Process finished with exit code 0
This is indeed very convenient. In most cases it is all you need.