赞
踩
# Preview one shuffled batch from an ImageFolder dataset as an image grid.
import numpy as np
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import make_grid

# Per-channel (RGB) mean/std used by ``transforms.Normalize`` below.
# They are defined once so that normalization in ``transform`` and the inverse
# applied in ``image_show`` always use identical values.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


def image_show(image, figsize=(50, 50)):
    """Display a normalized image tensor with matplotlib.

    Undoes ``transforms.Normalize`` using ``IMAGENET_MEAN``/``IMAGENET_STD``,
    clips the result to [0, 1] and shows it in a new figure.

    Args:
        image: 3-D tensor in (C, H, W) layout, e.g. the output of
            ``make_grid``. It is detached and copied to CPU before conversion,
            so GPU tensors or tensors that require grad are also accepted.
        figsize: Matplotlib figure size in inches (default keeps the
            original ``(50, 50)``).
    """
    # matplotlib expects (H, W, C); the tensor is (C, H, W).
    array = image.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array(IMAGENET_MEAN)
    std = np.array(IMAGENET_STD)
    # Inverse of Normalize: x_norm = (x - mean) / std  =>  x = std * x_norm + mean
    array = np.clip(std * array + mean, 0, 1)
    plt.figure(figsize=figsize)
    plt.imshow(array)
    plt.show()


def main():
    """Load one shuffled batch from ``train/`` and show it as a 4-column grid."""
    traindata = datasets.ImageFolder(root='train', transform=transform)
    trainloader = DataLoader(dataset=traindata, batch_size=64, shuffle=True,
                             num_workers=8)
    datas, targets = next(iter(trainloader))
    out = make_grid(datas, nrow=4, padding=10)
    image_show(out)
    print(targets)


# Required for num_workers > 0: with the "spawn" start method, DataLoader
# workers re-import this module, and unguarded top-level iteration would try
# to start new processes during their bootstrap and fail.
if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。