赞
踩
- import torch
- from torch.utils import data
- import os
- import numpy as np
- from PIL import Image
- #怎么制作数据集
- class dataset(data.Dataset):
- def __init__(self,path):
- self.path = path
- self.dataset =[] #当数据是较大的图片时,一次性不要全部加载进数据
- self.dataset.extend(os.listdir(path))#路径 路径里包含信息
- os.listdir()
- def __len__(self):
- return len(self.dataset)
- def __getitem__(self,index):
- lable=torch.Tensor([int(self.dataset[index][0])])#取出标签 通过numpy转tensor不容易出错
- img_path=os.path.join(self.path,self.dataset[index])
- img=Image.open(img_path)
- img_data=torch.Tensor(np.array((img))/255-0.5)#/255归1化 -0.5去均值化 后转成Tensor
- return img_data,lable
-
- #验证一下数据
- if __name__=='__main__':
- #1看取出的数据是否有问题
- train_dataset = dataset('D:\workFile\深度学习_神经网络\img')
- x=train_dataset[1][0]
- y=train_dataset[1][1]
-
- #举证转回图像,看看是否有问题
- x2 = train_dataset[0][0].numpy()
- y2 = train_dataset[1][1].numpy()
-
- img_data=np.array((x2+0.5)*255,dtype=np.int8)
- img = Image.fromarray(img_data,"RGB")
- img.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。