当前位置:   article > 正文

Pytorch制作自己的数据集

pytorch制作自己的数据集

1、Dataset+DataLoader实现自定义数据集读取方法
创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)应该返回数据集的大小;def __getitem__(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。

 基于下面的模板写类和方法 来加载自己的图像

  1. class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset
  2. def __init__(self):
  3. #对继承自父类的属性进行初始化(好像没有这句也可以??)
  4. super(MyDataset,self).__init__()
  5. # TODO
  6. #1、初始化一些参数和函数,方便在__getitem__函数中调用。
  7. #2、制作__getitem__函数所要用到的图片和对应标签的list。
  8. #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
  9. pass
  10. def __getitem__(self, index):
  11. # TODO
  12. #1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
  13. #2、预处理数据(例如torchvision.Transform)。
  14. #3、返回数据对(例如图像和标签)。
  15. #这里需要注意的是,这步所处理的是index所对应的一个样本。
  16. pass
  17. def __len__(self):
  18. #返回数据集大小
  19. return len()

下面是一个  按照上面的模板  写的例子

第一步:先收集几张图片作为自己的数据集,然后自己手动创建一个txt文件(关于txt文件,通过python和matlab可以很容易的创建)储存图片对应的label。

图像照片

 txt文件

第二步:按照上面的模板,制作自己的数据集类。 

  1. import torch
  2. import torchvision
  3. from torchvision import transforms
  4. from PIL import Image
  5. from torch.utils.data import Dataset
  6. from torch.utils.data import DataLoader
  7. #路径是自己电脑里所对应的路径
  8. datapath = r'E:\Python\DeepLearning\Datasets\testdata'
  9. txtpath = r'E:\Python\DeepLearning\Datasets\testdata\label.txt'
  10. class MyDataset(Dataset):
  11. def __init__(self,txtpath):
  12. #创建一个list用来储存图片和标签信息
  13. imgs = []
  14. #打开第一步创建的txt文件,按行读取,将结果以元组方式保存在imgs里
  15. datainfo = open(txtpath,'r')
  16. for line in datainfo:
  17. line = line.strip('\n')
  18. words = line.split()
  19. imgs.append((words[0],words[1]))
  20. self.imgs = imgs
  21. #返回数据集大小
  22. def __len__(self):
  23. return len(self.imgs)
  24. #打开index对应图片进行预处理后return回处理后的图片和标签
  25. def __getitem__(self, index):
  26. pic,label = self.imgs[index]
  27. pic = Image.open(datapath+'\\'+pic)
  28. pic = transforms.ToTensor()(pic)
  29. return pic,label
  30. #实例化对象
  31. data = MyDataset(txtpath)
  32. #将数据集导入DataLoader,进行shuffle以及选取batch_size
  33. data_loader = DataLoader(data,batch_size=2,shuffle=True,num_workers=0)
  34. #Windows里num_works只能为0,其他值会报错

代码很简单,你可以根据自己的实际数据集情况,修改上面的代码。

查看一下data_loader:

  1. for pics,label in data_loader:
  2. print(pics,label)

输出如下

  1. tensor([[[[0., 0., 0., ..., 0., 0., 0.],
  2. [0., 0., 0., ..., 0., 0., 0.],
  3. [0., 0., 0., ..., 0., 0., 0.],
  4. ...,
  5. [0., 0., 0., ..., 0., 0., 0.],
  6. [0., 0., 0., ..., 0., 0., 0.],
  7. [0., 0., 0., ..., 0., 0., 0.]]],
  8. [[[0., 0., 0., ..., 0., 0., 0.],
  9. [0., 0., 0., ..., 0., 0., 0.],
  10. [0., 0., 0., ..., 0., 0., 0.],
  11. ...,
  12. [0., 0., 0., ..., 0., 0., 0.],
  13. [0., 0., 0., ..., 0., 0., 0.],
  14. [0., 0., 0., ..., 0., 0., 0.]]]]) ('4', '2')
  15. tensor([[[[0., 0., 0., ..., 0., 0., 0.],
  16. [0., 0., 0., ..., 0., 0., 0.],
  17. [0., 0., 0., ..., 0., 0., 0.],
  18. ...,
  19. [0., 0., 0., ..., 0., 0., 0.],
  20. [0., 0., 0., ..., 0., 0., 0.],
  21. [0., 0., 0., ..., 0., 0., 0.]]],
  22. [[[0., 0., 0., ..., 0., 0., 0.],
  23. [0., 0., 0., ..., 0., 0., 0.],
  24. [0., 0., 0., ..., 0., 0., 0.],
  25. ...,
  26. [0., 0., 0., ..., 0., 0., 0.],
  27. [0., 0., 0., ..., 0., 0., 0.],
  28. [0., 0., 0., ..., 0., 0., 0.]]]]) ('2', '1')
  29. tensor([[[[0., 0., 0., ..., 0., 0., 0.],
  30. [0., 0., 0., ..., 0., 0., 0.],
  31. [0., 0., 0., ..., 0., 0., 0.],
  32. ...,
  33. [0., 0., 0., ..., 0., 0., 0.],
  34. [0., 0., 0., ..., 0., 0., 0.],
  35. [0., 0., 0., ..., 0., 0., 0.]]],
  36. [[[0., 0., 0., ..., 0., 0., 0.],
  37. [0., 0., 0., ..., 0., 0., 0.],
  38. [0., 0., 0., ..., 0., 0., 0.],
  39. ...,
  40. [0., 0., 0., ..., 0., 0., 0.],
  41. [0., 0., 0., ..., 0., 0., 0.],
  42. [0., 0., 0., ..., 0., 0., 0.]]]]) ('0', '1')
  43. tensor([[[[0., 0., 0., ..., 0., 0., 0.],
  44. [0., 0., 0., ..., 0., 0., 0.],
  45. [0., 0., 0., ..., 0., 0., 0.],
  46. ...,
  47. [0., 0., 0., ..., 0., 0., 0.],
  48. [0., 0., 0., ..., 0., 0., 0.],
  49. [0., 0., 0., ..., 0., 0., 0.]]],
  50. [[[0., 0., 0., ..., 0., 0., 0.],
  51. [0., 0., 0., ..., 0., 0., 0.],
  52. [0., 0., 0., ..., 0., 0., 0.],
  53. ...,
  54. [0., 0., 0., ..., 0., 0., 0.],
  55. [0., 0., 0., ..., 0., 0., 0.],
  56. [0., 0., 0., ..., 0., 0., 0.]]]]) ('4', '0')
  57. tensor([[[[0., 0., 0., ..., 0., 0., 0.],
  58. [0., 0., 0., ..., 0., 0., 0.],
  59. [0., 0., 0., ..., 0., 0., 0.],
  60. ...,
  61. [0., 0., 0., ..., 0., 0., 0.],
  62. [0., 0., 0., ..., 0., 0., 0.],
  63. [0., 0., 0., ..., 0., 0., 0.]]],
  64. [[[0., 0., 0., ..., 0., 0., 0.],
  65. [0., 0., 0., ..., 0., 0., 0.],
  66. [0., 0., 0., ..., 0., 0., 0.],
  67. ...,
  68. [0., 0., 0., ..., 0., 0., 0.],
  69. [0., 0., 0., ..., 0., 0., 0.],
  70. [0., 0., 0., ..., 0., 0., 0.]]]]) ('3', '3')

data_loader这个迭代器里存储的是每2个一组(batch_size)的图片像素信息以及对应的标签信息,也就是我们后续要导入到神经网络里的数据。

只是一个最简单的用来加深理解的例子,实际应用时会比这复杂很多,比如图像的渲染、变换等,但是基本流程都是一样的。

1.3、txt文件的生成

这个 很简单,用python脚本,写一个就行。

2、ImageFolder+DataLoader实现本地数据导入(这个很方便)
在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:
root:在root指定的路径下寻找图片
transform:对loader读取图片的返回对象进行转换操作(ToTensor等)
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

文件夹严格按照如下方式保存:
 

  1. .
  2. ├──train
  3. | ├──类别1
  4. | | ├──*.jpg
  5. | | ├──*.jpg
  6. | | └──...
  7. | ├──类别2
  8. | | ├──*.jpg
  9. | | ├──*.jpg
  10. | | └──...
  11. | └──...
  12. └──test
  13. ├──类别1
  14. | ├──*.jpg
  15. | ├──*.jpg
  16. | └──...
  17. ├──类别2
  18. | ├──*.jpg
  19. | ├──*.jpg
  20. | └──...
  21. └──...

实现代码如下

  1. transform = transforms.ToTensor()
  2. root = r'E:\Python\DeepLearning\Datasets\mymnist\train'
  3. # 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
  4. train_data = torchvision.datasets.ImageFolder(root, transform=transform)
  5. train_iter = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=0)
  6. test_data = torchvision.datasets.ImageFolder(root, transform=transform)
  7. test_iter = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0)

参考:

(87条消息) Pytorch创建自己的数据集(一)_生活所迫^_^的博客-CSDN博客_pytorch构建自己的数据集

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/548400
推荐阅读
相关标签
  

闽ICP备14008679号