赞
踩
在做计算机视觉相关任务,如图像分类时,需要使用PyTorch构建神经网络进行模型训练,这时候如果我们想使用自己的数据集进行训练,我们可以重构PyTroch的Dataset类来使用自己的数据集,或者按照ImageFolder类来加载自己的数据集,并使用DataLoder类创建Batch进行训练。
os模块内置的函数os.listdir()
可以遍历文件夹下的所有文件,而我们要将图片文件放在目标文件夹里
import os
os.listdir("D:/Desktop/Picture/")
# Out ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg', '6.jpg']
现在我们定义一下数据集的格式,相同类的图片存放在以类名作为文件夹名的文件夹内
然后我们重构一下torch的Dataset类
from torch.utils.data import Dataset as Dataset
首先定义一下数据标签,这个视自己的数据集而定
label_dic={"dog":0,"cat":1}
主要覆写Dataset
里的getitem
和len
方法
from PIL import Image class MyDataset(Dataset): def __init__(self,img_pth,label_dic,transform=None): super(MyDataset,self).__init__() label_ls=os.listdir(img_pth) self.img_pth = [] #存放完整的所有图片路径 for i in label_ls: for j in os.listdir(img_pth+"/"+i): self.img_pth.append(img_pth+"/"+i+"/"+j) self.transform = transform self.label_dic=label_dic #__getitem__会根据__len__创建一个迭代器 def __getitem__(self, index): img=Image.open(self.img_pth[index]) label=self.label_dic[self.img_pth[index].split("/")[-2]] #对图像数据进行标准化处理 if self.transform !=None: img=self.transform(img) return img,label def __len__(self): return len(self.img_pth)
此处未利用transform进行数据规约
torchvision.transforms
方法进行数据批处理import torchvision.transforms as transforms
transform=transforms.Compose([
transforms.RandomCrop(180), #长宽统一为180
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
])
归一化后
首先安装torchvision包
pip install torchvision
确保图片文件夹里每一类图片放在相同的文件夹里
from torchvision.datasets import ImageFolder
img_pth="D:/Desktop/Picture/" #文件夹路径
dataset=ImageFolder("D:/Desktop/Picture/")
print(dataset)
Out: Dataset ImageFolder
Number of datapoints: 12
Root location: D:/Desktop/Picture/
注意,使用ImageFolder函数会根据文件夹的读取顺序自动赋予标签【0~n】,如上图中cat文件夹优先于dog文件夹,所以cat里的图像标签为0,dog为1
可以使用以下语句查看类别名以及ImageFolder自动分配的类别和索引对应值
print(dataset.class_to_idx)
print(dataset.classes)
Out:{'cat': 0, 'dog': 1}
['cat', 'dog']
将上文的dataset直接传入DataLoader类
from torch.utils.data import DataLoader
dataloader=DataLoader(dataset,batch_size=2)
使用Dataloder主要是为了方便神经网络训练时进行批量梯度下降
查看dataloder内容
for data in dataloader:
img,label=data
print(img.shape,label)
可以看出,生成的dataloder数据变成了四维tensor
,标签则是一维向量的扩充(拼接)
实际上Dataset类只是一个容器,按照索引返回**(img,label)**,我们要对__getitem__
和__len__
方法进行修改,主要涉及到文件名的索引,所以我们要对os
模块的掌握较为熟悉,最后能成功返回一个迭代器即可。其与DataLoader的配合才更为重要,因为后者是神经网络训练中必不可少的数据处理方法。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。