赞
踩
我们经常会遇到这样的问题,就是如何使用自己的数据集,把标签和图片对应起来,然后转化成一个一个批次送进网络。其实在pytorch中已经为我们封装好了各种库,只需要我们添加相应的处理就好。
import torch.utils.data #子类化数据
import torch
from tochvision import transforms #数据处理
定义自己的dataset类:
class MyTrainData(torch.utils.data.Dataset)
该类包括初始化参数,index获取图片和标签,以及最后返回数据集的长度。
#encoding:utf-8 import torch.utils.data as data import torch from torchvision import transforms class MyTrainData(torch.utils.data.Dataset) #子类化 def __init__(self, root, transform=None, train=True): #第一步初始化各个变量 self.root = root self.train = train def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取 # get item 获取 数据 img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高 img = torch.from_numpy(img).float() #需要转成float gt = imread(gt_path) #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1 gt = torch.from_numpy(gt).float() return img, gt #返回 一一对应 def __len__(self): return len(self.imagenumber) #这个是必须返回的长度
transform类函数定义了各种数据的操作,比如旋转,剪切,缩放等等,我们可以根据自己的需要添加需要的函数。
1、完整代码
#encoding:utf-8 import torch.utils.data as data import torch from scipy.ndimage import imread import os import os.path import glob from torchvision import transforms def make_dataset(root, train=True): #读取自己的数据的函数 dataset = [] if train: dirgt = os.path.join(root, 'train_data/groundtruth') dirimg = os.path.join(root, 'train_data/imgs') for fGT in glob.glob(os.path.join(dirgt, '*.jpg')): # for k in range(45) fName = os.path.basename(fGT) fImg = 'train_ori'+fName[8:] dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] ) return dataset #自定义dataset的框架 class MyTrainData(data.Dataset): #需要繼承data.Dataset def __init__(self, root, transform=None, train=True): #初始化文件路進或文件名 self.train = train if self.train: self.train_set_path = make_dataset(root, train) def __getitem__(self, idx): if self.train: img_path, gt_path = self.train_set_path[idx] img = imread(img_path) img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32) img = (img - img.min()) / (img.max() - img.min()) img = torch.from_numpy(img).float() gt = imread(gt_path) gt = np.atleast_3d(gt).transpose(2, 0, 1) gt = gt / 255.0 gt = torch.from_numpy(gt).float() return img, gt def __len__(self): return len(self.train_set_path)
2、另一种方式
import torch.nn.functional as F import torch import torch import torch.nn as nn from torch.autograd import Variable import torchvision.models as models from torchvision import transforms, utils from torch.utils.data import Dataset, DataLoader from PIL import Image import numpy as np import torch.optim as optim import os #torch.cuda.set_device(gpu_id)#使用GPU learning_rate = 0.0001 #数据集的设置***************************************************************************************************************** root =os.getcwd()+ '/data1/'#调用图像 #定义读取文件的格式 def default_loader(path): return Image.open(path).convert('RGB') #首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取: class MyDataset(Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset def __init__(self,txt, transform=None,target_transform=None, loader=default_loader): #初始化一些需要传入的参数 super(MyDataset,self).__init__()#对继承自父类的属性进行初始化 fh = open(txt, 'r')#按照传入的路径和txt文本参数,打开这个文本,并读取内容 imgs = [] for line in fh: #迭代该列表#按行循环txt文本中的内 line = line.strip('\n') line = line.rstrip('\n')# 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python words = line.split() #用split将该行分割成列表 split的默认参数是空格,所以不传递任何参数时分割空格 imgs.append((words[0],int(words[1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定 # 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index):#这个方法是必须要有的,用于按照索引读取每个元素的具体内容 fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息 img = self.loader(fn) # 按照路径读取图片 if self.transform is not None: img = self.transform(img) #数据标签转换为Tensor return img,label#return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容 def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分 return len(self.imgs) #根据自己定义的那个MyDataset来创建数据集!注意是数据集!而不是loader迭代器 #*********************************************数据集读取完毕******************************************************************** #图像的初始化操作 train_transforms = transforms.Compose([ transforms.RandomResizedCrop((227,227)), transforms.ToTensor(), ]) text_transforms = transforms.Compose([ transforms.RandomResizedCrop((227,227)), transforms.ToTensor(), ]) #数据集加载方式设置 train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor()) test_data = MyDataset(txt=root+'text.txt', transform=transforms.ToTensor()) #然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关 train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True,num_workers=4) test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4) print('num_of_trainData:', len(train_data)) print('num_of_testData:', len(test_data))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。