当前位置:   article > 正文

在pytorch中使用自己的数据集,dataset的写法_用dataset生成测试fizbuz数据集

用dataset生成测试fizbuz数据集

引入

在学习pytorch的过程中,用的一直都是教程中别人定义好从网上直接下载的数据集,不需要进行任何的处理,数据和标号都可以直接获取。但是,我想要进行自己的研究大多数情况需要我们自己收集数据并进行一些预处理在制作成数据集,然后通过pytorch读入后用来训练模型。这里记录的是一次对上万张验证码图片组成的数据集(标号是其名称)制作pytorch数据集的尝试。

部分数据如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BF0drwOa-1647960599298)(attachment:51c03f9d-a46c-4994-8e7e-36c078c6a724.png)]

大多数教程中并没有讲这些图片数据和标签是如何装载到torch中的,在分析了一个github项目https://github.com/braveryCHR/CNN_captcha 后我大概了解如何装载数据。

方法

如果我们需要利用pytorch装载数据以及标签,我们就必须自己写一个dataset类,该类要继承data.Dataset类,该类在torch.utils中,并实现该类的_getitem_和_len_方法。
示例:

为了实现将验证码分类,我们先定义label和字符互相转换的函数:

import os

import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms as T

def StrToLabel(Str):
    # print(Str)
    label = []
    for i in range(0, charNumber):
        if '0' <= Str[i] <= '9':  # 数字
            label.append(ord(Str[i]) - ord('0'))
        elif 'a' <= Str[i] <= 'z':  # 小写字母
            label.append(ord(Str[i]) - ord('a') + 10)
        else:  # 大写字母
            label.append(ord(Str[i]) - ord('A') + 36)
    return label


def LabelToStr(Label):
    Str = ""
    for i in Label:
        if i <= 9:
            Str += chr(ord('0') + i)
        elif i <= 35:
            Str += chr(ord('a') + i - 10)
        else:
            Str += chr(ord('A') + i - 36)
    return Str
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

接下来是数据集合类的定义

class Captcha(data.Dataset):
    def __init__(self, root, train=True):
        self.imgPath = [os.path.join(root, img) for img in os.listdir(root)]
        self.transform = T.Compose([
            T.Resize((150, 30)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        img_path = self.imgPath[index]
        label = img_path.split('\\')[-1].split('.')[0]       #获取图片标签
        label_tensor = torch.Tensor(StrToLabel(label))
        data=Image.open(img_path)
        data = self.transform(data)  # 使用PLT打开图片文件
        return data, label_tensor

    def __len__(self):
        return len(self.imgPath)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

在init中的transform是预处理的定义。

getitem方法用来返回读取的图片数据和该图片的参数,我们将图片文件名获取到并转换为tensor,再使用PIL模块中的Image.open()读取图片数据,之后通过预处理transform转为tensor对象,最后返回图片数据data和图片标签label_tensor就可以了。


len函数返回文件中图片的数量。


dataloader会根据len读取文件中所有图片,每次读取图片的方法就是getitem中定义的方法。

测试

我们来使用一下这个Capthca类,看看能否正确读取图片数据data以及其标号label

import os.path
import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹

img_data = Captcha("./data/train/train", train=True)
trainDataLoader = DataLoader(img_data, batch_size=1,
                             shuffle=False, num_workers=4)

if __name__ == '__main__':
    # for i, data in enumerate(trainDataLoader, 0):
    #     inputs, label = data
    #     print(label)
    it = trainDataLoader.__iter__()#使用迭代器返回第一张图片的数据和标签
    data, label = it.next()
    print(data)
    print(label)
    print(LabelToStr(int(x)for x in label.squeeze().tolist()))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

由于在jupyter中运行该代码会报错所以我放上在pycharm上的运行结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bjSdGSsA-1647960599299)(attachment:f17724e4-267e-40ad-ab0d-5f041787eee2.png)]

总结

想要使用自己定义的数据集就必须实现一个dataset,使得dataloader知道如何获取数据以及标签。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/372589
推荐阅读
相关标签
  

闽ICP备14008679号