当前位置:   article > 正文

PyTorch自制数据集--使用自己的数据集_pytorch制作自己的图片数据集

pytorch制作自己的图片数据集

PyTorch 自制数据集

在做计算机视觉相关任务,如图像分类时,需要使用PyTorch构建神经网络进行模型训练,这时候如果我们想使用自己的数据集进行训练,我们可以重构PyTroch的Dataset类来使用自己的数据集,或者按照ImageFolder类来加载自己的数据集,并使用DataLoder类创建Batch进行训练。

重构Dataset类实现读取自己的数据集

os模块内置的函数os.listdir()可以遍历文件夹下的所有文件,而我们要将图片文件放在目标文件夹里

import os
os.listdir("D:/Desktop/Picture/")
# Out ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg', '6.jpg']
  • 1
  • 2
  • 3

在这里插入图片描述
现在我们定义一下数据集的格式,相同类的图片存放在以类名作为文件夹名的文件夹内
在这里插入图片描述
然后我们重构一下torch的Dataset类

from torch.utils.data import Dataset as Dataset
  • 1

首先定义一下数据标签,这个视自己的数据集而定

label_dic={"dog":0,"cat":1}
  • 1

主要覆写Dataset里的getitemlen方法

from PIL import Image
class MyDataset(Dataset):
    def __init__(self,img_pth,label_dic,transform=None):
        super(MyDataset,self).__init__()
        label_ls=os.listdir(img_pth)
        self.img_pth = []	#存放完整的所有图片路径
        for i in label_ls:
            for j in os.listdir(img_pth+"/"+i):
                self.img_pth.append(img_pth+"/"+i+"/"+j)
        self.transform = transform
        self.label_dic=label_dic
    #__getitem__会根据__len__创建一个迭代器
    def __getitem__(self, index):
        img=Image.open(self.img_pth[index])
        label=self.label_dic[self.img_pth[index].split("/")[-2]]
        #对图像数据进行标准化处理
        if self.transform !=None:
            img=self.transform(img)
        return img,label
        
    def __len__(self):
        return len(self.img_pth)
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

此处未利用transform进行数据规约
在这里插入图片描述

使用torchvision.transforms方法进行数据批处理

import torchvision.transforms as transforms
transform=transforms.Compose([
    transforms.RandomCrop(180),	#长宽统一为180
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
])
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

归一化后
在这里插入图片描述

torchvision.datasets.ImageFolder

首先安装torchvision包

pip install torchvision
  • 1

确保图片文件夹里每一类图片放在相同的文件夹里
在这里插入图片描述

from torchvision.datasets import ImageFolder

img_pth="D:/Desktop/Picture/"	#文件夹路径
dataset=ImageFolder("D:/Desktop/Picture/")
print(dataset)

Out:	Dataset ImageFolder
	    Number of datapoints: 12
	    Root location: D:/Desktop/Picture/
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

注意,使用ImageFolder函数会根据文件夹的读取顺序自动赋予标签【0~n】,如上图中cat文件夹优先于dog文件夹,所以cat里的图像标签为0,dog为1
可以使用以下语句查看类别名以及ImageFolder自动分配的类别和索引对应值

print(dataset.class_to_idx)
print(dataset.classes)

Out:{'cat': 0, 'dog': 1}
	['cat', 'dog']
  • 1
  • 2
  • 3
  • 4
  • 5

DataLoader类生成Batch进行训练

将上文的dataset直接传入DataLoader类

from torch.utils.data import DataLoader

dataloader=DataLoader(dataset,batch_size=2)
  • 1
  • 2
  • 3

使用Dataloder主要是为了方便神经网络训练时进行批量梯度下降
查看dataloder内容

for data in dataloader:
	img,label=data
	print(img.shape,label)
  • 1
  • 2
  • 3

在这里插入图片描述
可以看出,生成的dataloder数据变成了四维tensor,标签则是一维向量的扩充(拼接)

小结

实际上Dataset类只是一个容器,按照索引返回**(img,label)**,我们要对__getitem____len__方法进行修改,主要涉及到文件名的索引,所以我们要对os模块的掌握较为熟悉,最后能成功返回一个迭代器即可。其与DataLoader的配合才更为重要,因为后者是神经网络训练中必不可少的数据处理方法。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/364852
推荐阅读
相关标签
  

闽ICP备14008679号