当前位置:   article > 正文

PyTorch中如何读取数据_pytorch里的读取数据的代码

pytorch里的读取数据的代码

一、Dataset函数的使用

        当我们得到一个数据集时,Dataset类可以帮我们提取我们需要的数据,我们用子类继承Dataset类,我们先给每个数据一个编号(idx),在后面的神经网络中,初始化Dataset子类实例后,就可以通过这个编号去实例对象中读取相应的数据,会自动调用__getitem__方法,同时子类对象也会获取相应真实的Label(人为去复写即可)

Dataset类的作用:提供一种方式去获取数据及其对应的真实Label

在Dataset类的子类中,应该有以下函数以实现某些功能:

  1. 获取每一个数据及其对应的Label(数据和label可能在同一个文件中,比如分别对应图片内容和图片名也可能分别在两个文件中)
  2. 统计数据集中总共的数据数量

具体使用例子:

        新创建一个项目,在项目中新建一个python文件命名为read_data,将数据集datas保存在该py文件所在的目录下。

【这里使用的视频去哪了?-创建者去哪了?-播单去哪了?-哔哩哔哩视频视频下的蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq】

【上面那个B站链接能打开】

代码如下:

(由于昨天妄想从头学python,发现课程基本从整型浮点型讲,太墨迹,决定还是决定边实战边学,都放在注释里。有错的地方烦请指正,假如有人看的话。)

  1. from torch.utils.data import Dataset as Dataset
  2. # from...import...: 选择性地导入库中的一个或多个函数、变量和类,as重新命名
  3. # torch.utils.data模块包含一些常用的数据预处理的操作,主要用于数据的读取、切分、准备;Dataset是其中的一个抽象类
  4. from PIL import Image
  5. # PIL:Python Imaging Library,是一个Python的第三方库,用于图像处理
  6. '''
  7. 这几行代码用于测试Image用法,能够成功打开图片
  8. img_path = "E:\\pyProject\\pythonProject\\learn_dataset\\datas\\train\\ants_image\\7759525_1363d24e88.jpg"
  9. img_path = "datas\\train\\ants_image\\7759525_1363d24e88.jpg"
  10. 上面两个一个是绝对路径一个是相对路径,都能够正常使用。(相对路径应该只是和py文件在同一文件夹下能用))
  11. 注意win系统下是用的双\\,表示转义,一个会报错
  12. img = Image.open(img_path)
  13. img.show()
  14. '''
  15. import os
  16. # os 是 operation system 的缩写,支持文件和目录操作,进程管理,环境变量管理等功能
  17. # 定义一个类, 继承了 Dataset 这个类
  18. class MyData(Dataset):
  19. # 构造函数,初始化类,self表示创建的实例本身,在__init__方法内部可以把各种属性绑定到self,为整个类提供全局变量
  20. # root_dir
  21. def __init__(self, root_dir, label_dir):
  22. self.root_dir = root_dir
  23. self.label_dir = label_dir
  24. self.path = os.path.join(self.root_dir, self.label_dir) # 将两个路径拼接,获得图片路径地址
  25. self.img_path = os.listdir(self.path) # 获得所有图片的列表
  26. # 成员函数,在Dataset子类中都必须重写__getitem__方法,该方法用于获取label
  27. # idx是index缩写,指的索引
  28. def __getitem__(self, idx):
  29. img_name = self.img_path[idx] # 为啥是【】,获取索引值对应的图片名字
  30. img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 获取每个图片的位置
  31. img = Image.open(img_item_path)
  32. label = self.label_dir
  33. return img, label
  34. # 重写__len__方法,返回数据集长度
  35. def __len__(self):
  36. return len(self.img_path)
  37. root_dir = "datas/val"
  38. ants_label_dir = "ants"
  39. bees_label_dir = "bees"
  40. # 对于类创建了蚂蚁和蜜蜂两个实例(数据集)
  41. ants_dataset = MyData(root_dir, ants_label_dir)
  42. bees_dataset = MyData(root_dir, bees_label_dir)
  43. # 输出两个数据集的长度。注意同时输出字符串和数字时,需要将数字强制转换为字符串类型
  44. # 在python中print自带换行啊,不需要\n换行
  45. print("蚂蚁数据集长度:"+str(len(ants_dataset)))
  46. print("蜜蜂数据集长度:"+str(len(bees_dataset)))
  47. '''
  48. 测试
  49. img1, label = ants_dataset[0]
  50. img1.show()
  51. img2, label = bees_dataset[0]
  52. img2.show()
  53. '''
  54. # 将两个数据集拼接起来,得到两个数据集的集合
  55. # 这个方法可以用来仿造大的数据集
  56. train_dataset = ants_dataset + bees_dataset
  57. '''
  58. 测试,该数据集前70个是蚂蚁的,后83个是蜜蜂的
  59. img, label = train_dataset[69]
  60. img.show()
  61. img, label = train_dataset[70]
  62. img.show()
  63. '''

         在有的时候label很长,就要用一个专门的文件来保存,一般是在另一个文件夹中与图片同名的txt文件内。

二、DataLoader类的使用

dataloader:构建可迭代的数据装载器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

这个还没学呢,下次再说。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/455145
推荐阅读
相关标签
  

闽ICP备14008679号