赞
踩
Pytorch torchvision 包提供了很多常用数据集
数据按照用途一般分为三组:训练(train)、验证(validation)和测试(test)。使用训练数据集来训练模型,使用验证数据集跟踪模型在训练期间的性能,使用测试数据集对模型进行最终评估。
目录
从 torchvision导入MNIST训练数据集
- import torch
- import torchvision
- from torchvision import datasets
- train_data=datasets.MNIST("./data",train=True,download=True)
datasets.MNIST是Pytorch的内置函数
train=True指导入的数据作为训练数据集
download=True若根目录下没有数据集时自动下载
导入完成后可以看到MINST文件内的数据集
- x_train, y_train=train_data.data,train_data.targets
- print(x_train.shape)
- print(y_train.shape)
x_train存储60000张28*28的图片,y_train存储60000张图片对应的数字(label)
从 torchvision导入MNIST验证数据集并提取数据和标签
- val_data=datasets.MNIST("./data", train=False, download=True)
- x_val,y_val=val_data.data, val_data.targets
- print(x_val.shape)
- print(y_val.shape)
Pytorch中张量可以是一维、二维、三维或者更高维度的数据结构。一维张量类似于向量,二维张量类似于矩阵,三维张量类似一系列矩阵的堆叠。添加新的维度可以更好地对数据进行表示和处理。
- if len(x_train.shape)==3:
- x_train=x_train.unsqueeze(1)
- print(x_train.shape)
-
- if len(x_val.shape)==3:
- x_val=x_val.unsqueeze(1)
- print(x_val.shape)
.unsqueeze(0)指添加在第一个维度
也可以通过x_train.view(60000,1,28,28)添加维度
可以看到张量由三维变为了四维
引入所需的包,定义一个辅助函数,将张量显示为图像
- from torchvision import utils
- import matplotlib.pyplot as plt
- import numpy as np
- def show(img):
- npimg = img.numpy()
- npimg_tr=np.transpose(npimg, (1,2,0))
- plt.imshow(npimg_tr,interpolation='nearest')
创建一个10*10的网格,每行10张图片,pedding=3指间隔为3
- x_grid=utils.make_grid(x_train[:100], nrow=10, padding=3)
- print(x_grid.shape)
- show(x_grid)
utils.make_grid实际上是将多张图片拼接起来,参照官方介绍:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。