赞
踩
在学习pytorch的过程中,用的一直都是教程中别人定义好从网上直接下载的数据集,不需要进行任何的处理,数据和标号都可以直接获取。但是,我想要进行自己的研究大多数情况需要我们自己收集数据并进行一些预处理在制作成数据集,然后通过pytorch读入后用来训练模型。这里记录的是一次对上万张验证码图片组成的数据集(标号是其名称)制作pytorch数据集的尝试。
部分数据如下:
大多数教程中并没有讲这些图片数据和标签是如何装载到torch中的,在分析了一个github项目https://github.com/braveryCHR/CNN_captcha 后我大概了解如何装载数据。
为了实现将验证码分类,我们先定义label和字符互相转换的函数:
import os import torch from PIL import Image from torch.utils import data import numpy as np from torch.utils.data import DataLoader from torchvision import transforms as T def StrToLabel(Str): # print(Str) label = [] for i in range(0, charNumber): if '0' <= Str[i] <= '9': # 数字 label.append(ord(Str[i]) - ord('0')) elif 'a' <= Str[i] <= 'z': # 小写字母 label.append(ord(Str[i]) - ord('a') + 10) else: # 大写字母 label.append(ord(Str[i]) - ord('A') + 36) return label def LabelToStr(Label): Str = "" for i in Label: if i <= 9: Str += chr(ord('0') + i) elif i <= 35: Str += chr(ord('a') + i - 10) else: Str += chr(ord('A') + i - 36) return Str
接下来是数据集合类的定义
class Captcha(data.Dataset): def __init__(self, root, train=True): self.imgPath = [os.path.join(root, img) for img in os.listdir(root)] self.transform = T.Compose([ T.Resize((150, 30)), T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) def __getitem__(self, index): img_path = self.imgPath[index] label = img_path.split('\\')[-1].split('.')[0] #获取图片标签 label_tensor = torch.Tensor(StrToLabel(label)) data=Image.open(img_path) data = self.transform(data) # 使用PLT打开图片文件 return data, label_tensor def __len__(self): return len(self.imgPath)
在init中的transform是预处理的定义。
getitem方法用来返回读取的图片数据和该图片的参数,我们将图片文件名获取到并转换为tensor,再使用PIL模块中的Image.open()读取图片数据,之后通过预处理transform转为tensor对象,最后返回图片数据data和图片标签label_tensor就可以了。
len函数返回文件中图片的数量。
dataloader会根据len读取文件中所有图片,每次读取图片的方法就是getitem中定义的方法。
我们来使用一下这个Capthca类,看看能否正确读取图片数据data以及其标号label
import os.path import torch import torchvision from torch import nn, optim import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms # 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹 img_data = Captcha("./data/train/train", train=True) trainDataLoader = DataLoader(img_data, batch_size=1, shuffle=False, num_workers=4) if __name__ == '__main__': # for i, data in enumerate(trainDataLoader, 0): # inputs, label = data # print(label) it = trainDataLoader.__iter__()#使用迭代器返回第一张图片的数据和标签 data, label = it.next() print(data) print(label) print(LabelToStr(int(x)for x in label.squeeze().tolist()))
由于在jupyter中运行该代码会报错所以我放上在pycharm上的运行结果:
想要使用自己定义的数据集就必须实现一个dataset,使得dataloader知道如何获取数据以及标签。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。