当前位置:   article > 正文

可视化学习笔记9-pytorch cifar10数据可视化,归一化可视化。_cifar10数据集下载可视化

cifar10数据集下载可视化

cifar10数据可视化

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
import torchvision
import torchvision.transforms as transforms


#下载数据预处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform )
trainloader=torch.utils.data.DataLoader(trainset,batch_size=5,shuffle=True,num_workers=0)
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=transform )
testloader=torch.utils.data.DataLoader(testset,batch_size=5,shuffle=False)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')#10个类

#可视化
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.savefig('./img1.png')
    plt.show()

    #随机获取部分训练数据
dataiter=iter(testloader)  #依次取出迭代器里的值。执行一次只能取到迭代器里的一个值
images,labels=dataiter.next()
#显示图像
imshow(torchvision.utils.make_grid(images))
print(''.join('%5s'% classes[labels[j]] for j in range(4) ))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

在这里插入图片描述

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
torch.cuda.set_device(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 40
Epochs = 250


trans = torchvision.transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
cifar_10 = torchvision.datasets.CIFAR10(root='./data', train=True, transform=trans, download=False)
data_loader = DataLoader(cifar_10, batch_size=batch_size, shuffle=True)

def imshow(img):

    #反归一化,将数据重新映射到0-1之间
    img = img / 2 + 0.5
    plt.figure(figsize=(10, 4))
    plt.imshow(np.transpose(img.numpy(), (1,2,0)))
    plt.show()


for i, (images, _) in enumerate(data_loader):

    print(i)
    print(images.numpy().shape)
    # plt.subplot(4, 10, images + 1)
    imshow(torchvision.utils.make_grid(images))
    break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/214856
推荐阅读
相关标签
  

闽ICP备14008679号