当前位置:   article > 正文

DataLoader自定义数据集制作_dataloader 自己的数据

dataloader 自己的数据

如何自定义数据集:

- 1.数据和标签的目录结构先搞定(得知道到哪读数据)
- 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
- 3.完成单个数据与标签读取函数(给dataloader举一个例子)

以花朵数据集为例:

- 原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)
- 这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
- 核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦

任务1:读取txt文件中的路径和标签

  • 第一个小任务,从标注文件中读取数据和标签
  • 至于你准备存成什么格式,都可以的,一会能取出来东西就行

任务2:分别把数据和标签都存在list里

 - 不是我非让你存list里,因为dataloader到时候会在这里取数据
- 按照人家要求来,不要耍个性,让整list咱就给人家整

 任务3:图像数据路径得完整
- 因为一会咱得用这个路径去读数据,所以路径得加上前缀
- 以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像

任务4:把上面那几个事得写在一起

- 1.注意要使用from torch.utils.data import Dataset, DataLoader
- 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
- 3.def __init__(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
- 4.def __getitem__(self, idx):根据自己任务,返回图像数据和标签数据

任务5:数据预处理(transform)

- 1.预处理的事都在上面的__getitem__中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整
- 2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
- 3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法

 任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

- 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
- 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
- 3.打印看看数据里面是不是有东西了

 

任务7:用之前先试试,整个数据和标签对应下

- 1.别着急往模型里传,对不对都不知道呢
- 2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
- 3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题

 代码实现

  1. import os
  2. from matplotlib import pyplot as plt
  3. from torchvision import transforms, models, datasets
  4. import numpy as np
  5. import torch
  6. from PIL import Image
  7. def load_annotations(ann_file):
  8. data_infos = {}
  9. with open(ann_file) as f:
  10. samples = [x.strip().split(' ') for x in f.readlines()]
  11. for filename, gt_label in samples:
  12. data_infos[filename] = np.array(gt_label, dtype=np.int64)
  13. return data_infos
  14. img_label =load_annotations('./flower_data/train.txt')
  15. image_name = list(img_label.keys())
  16. label = list(img_label.values())
  17. data_dir = './flower_data/'
  18. train_dir = data_dir + '/train_filelist'
  19. valid_dir = data_dir + '/val_filelist'
  20. image_path = [os.path.join(train_dir,img) for img in image_name]
  21. from torch.utils.data import Dataset, DataLoader
  22. class FlowerDataset(Dataset):
  23. def __init__(self, root_dir, ann_file, transform=None):
  24. self.ann_file = ann_file
  25. self.root_dir = root_dir
  26. self.img_label = self.load_annotations()
  27. self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
  28. self.label = [label for label in list(self.img_label.values())]
  29. self.transform = transform
  30. def __len__(self):
  31. return len(self.img)
  32. def __getitem__(self, idx):
  33. image = Image.open(self.img[idx])
  34. label = self.label[idx]
  35. if self.transform:
  36. image = self.transform(image)
  37. label = torch.from_numpy(np.array(label))
  38. return image, label
  39. def load_annotations(self):
  40. data_infos = {}
  41. with open(self.ann_file) as f:
  42. samples = [x.strip().split(' ') for x in f.readlines()]
  43. for filename, gt_label in samples:
  44. data_infos[filename] = np.array(gt_label, dtype=np.int64)
  45. return data_infos
  46. data_transforms = {
  47. 'train':
  48. transforms.Compose([
  49. transforms.Resize(64),
  50. transforms.RandomRotation(45),#随机旋转,-4545度之间随机选
  51. transforms.CenterCrop(64),#从中心开始裁剪
  52. transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
  53. transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
  54. transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
  55. transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
  56. transforms.ToTensor(),
  57. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
  58. ]),
  59. 'valid':
  60. transforms.Compose([
  61. transforms.Resize(64),
  62. transforms.CenterCrop(64),
  63. transforms.ToTensor(),
  64. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  65. ]),
  66. }
  67. train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './flower_data/train.txt', transform=data_transforms['train'])
  68. val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './flower_data/val.txt', transform=data_transforms['valid'])
  69. train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  70. val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
  71. image, label = next(iter(train_loader))
  72. sample = image[0].squeeze()
  73. sample = sample.permute((1, 2, 0)).numpy()
  74. sample *= [0.229, 0.224, 0.225]
  75. sample += [0.485, 0.456, 0.406]
  76. plt.imshow(sample)
  77. plt.show()
  78. print('Label is: {}'.format(label[0].numpy()))

 

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/829417
推荐阅读
相关标签
  

闽ICP备14008679号