当前位置:   article > 正文

Pytorch中Dataset类的继承使用方法(创建自己的图片数据集)(自己记录加强印象的,大家别喷我)_怎样继承自torch中的dataset

怎样继承自torch中的dataset

当我们有自己爬取的图片时,我们想要通过这些图片来训练自己的网络,该怎么办呢?(看不懂加我QQ探讨3014457121)

那就是继承Dataset类

想要实现这个类,必须要重写3个方法:init(self, 参数…)、 getitem(self, index)、len(self)

1、首先我们来介绍__init__(self, 参数…)
这个函数你可以传多个参数进来,以用来初始化你自己类中的各个变量。例如我传的是图片的母文件夹路径root

    def __init__(self, root):
        # 所有图片的绝对路径,root只定位到文件夹一级,os.listdir()只返回该文件下的文件名
        imgs = os.listdir(root)     # 这边先传入全部的数据
        self.imgs = [os.path.join(root, k) for k in imgs]
        self.transforms = None
        self.transforms = transform
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

imgs:拿到了所有图片的名字并形成一个list类型的列表。
然后类中的变量self.imgs就获得了所有图片的路径并形成一个列表。
下面我们来说说transforms变量是啥。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),   # 把通道数扩大3倍
    transforms.Resize(600),         # 让最短边为600
    transforms.CenterCrop(600),     # 从中间切出600*600的图片
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

transforms.ToTensor():让图片变成pytorch框架所需要的张量形式(CHW维度)
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x):我爬取的图片里有些是1通道的,即C=1,但是pytorch必须要保持数据集的维度格式一模一样,所以我要把那些通道数是1的图片强行转为通道数为3
接下来2行看注释
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]):对得到的张量里的数字进行概率分布,用为这对训练友好。

2、再来看看__getitem__(self, index)方法:
主要用来根据传入的索引来返回某个数据,像这样用的:

train_dataset = FlameSet('./images')
for i in range(73):
     print(train_dataset[i].shape)
  • 1
  • 2
  • 3

下面是函数体部分:

    def __getitem__(self, index):   # 这个函数利用索引读出单独的数据
        img_path = self.imgs[index]
        pil_img = Image.open(img_path)
        if self.transforms:     # 有transform(最好用这个)
            data = self.transforms(pil_img)     # C, H, W
        else:                   # 无transform(尽量别用)
            pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)    # H,W,C
            data = data.permute(2, 0, 1)        # ==> C,H,W 维度转变了

        return data
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

这个函数一般只需要传入一个index索引变量

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号