当前位置:   article > 正文

图像分类数据集FashionMNIST-动手学深度学习pytorch_d2l.load_data_fashion_mnist

d2l.load_data_fashion_mnist

1.加载图像集FashionMNIST

1.1导入相关库

注:运行环境:jupyter notebook

%matplotlib inline #jupyter notebook里面用,如果pycharm那么就直接注释掉
import torch
from torch.utils import data
from torchvision import transforms#常见的图片变换库
from d2l import torch as d2l
import os
os.environ["KMP_DUPLICATE_OK"]="TRUE"
d2l.use_svg_display()#
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

1.2加载MNIST数据集

如果数据集下载失败,请查看文章:Fashion-MNIST数据集本地下载及加载
重点:第一次download=True;第二次就不要下载了,之后设置download = False

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式
# 并除以255使得所有像素的数值均在0到1之间
# 第一次download=True;第二次之后设置download = False
# torchvision.datasets.FashionMNIST 直接用torchvision.datasets包下载数据集
# train=True表示:下载的是训练数据集;train=False表示:下载的是测试数据集
# transform=trans表示:拿出来后得到的是pytorch的tensor而不是一堆图片
trans = transforms.ToTensor() # 数据预处理,将PIL转换成张量
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
                                                transform=trans,
                                                download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
                                               transform=trans, download=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

1.3.查看下载数据的大小

len(mnist_train),len(mnist_test)
  • 1

在这里插入图片描述

1.4.每个像素的大小28 X 28

第一张图片的形状

mnist_train[0][0].shape
  • 1

在这里插入图片描述
1:表示RGB中表示一个通道,就是黑白图片
28,28表示长、宽

1.5.获取标签名字和对应的序号

Fashion-MNIST中包含的10个类别分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签。自定义并排序"""
    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',
        'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

1.6.定义函数show_images显示标签图片

matplotlib进行画图显示图片

def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):
    figsize = (num_cols *scale,num_rows*scale)
    _,axes = d2l.plt.subplots(num_rows,num_cols,figsize=figsize)
    axes = axes.flatten()
    for i,(ax,img) in enumerate(zip(axes,imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.inshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

1.7.显示图片

data.DataLoader数据加载类

x,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))
show_images(x.reshape(18,28,28),2,9,titles=get_fashion_mnist_lables(y))
  • 1
  • 2

在这里插入图片描述

1.8.读取一小批量数据,大小为batch_size

使用4个进程来读取数据,shuffle=True表示可以随机读取数据

batch_size = 256

def get_dataloader_workers():
    
    return 4
train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers = get_dataloader_workers())

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

1.8.1测试加载数据时间:

timer = d2l.Timer()
for x,y in train_iter:
    continue
f'{timer.stop():.2f}sec'
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.9.整合所有组件

下载Fashion_MNIST数据集,然后将其加载导内存中

def load_data_fashion_mnist(batch_size,resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0,transforms.Resize(resize))
    trans=transforms.Compose(trans)
    mnist_train=torchvision.datasets.FashionMNIST(root="../data",
                                                  train=True,
                                                  transform=trans,
                                                  download=False)
    mnist_test=torchvision.datasets.FashionMNIST(root="../data",
                                                  train=False,
                                                  transform=trans,
                                                  download=False)
    return (data.DataLoader(mnist_train,batch_size,shuffle=True,
                             num_workers=get_dataloader_workers()),
             data.DataLoader(mnist_test,batch_size,shuffle=False,
                             num_workers=get_dataloader_workers()))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

1.10. 指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
  • 1
  • 2
  • 3
  • 4

结果:

torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
  • 1

2.softmax回归从零开始

2.1导入相关库

用d2l.load_data_fashion_mnist每次随机读取256张图片,并将结果返回给train_iter训练集迭代器,测试集迭代器test_iter

import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.2初始化模型参数

由上面可知,每个样本图片为28X28的图像,通道数为1,那么这个是一个3D的输入,但是对于softmax回归来讲,我们需要的是一个向量,所以我们需要将整个图片拉长成一个向量784=28X28,这个过程就会损失掉很多空间信息,后面就只能交给卷积神经网络来处理。由于我们定义的标签数为10个,所以输出数据num_outputs为10.

num_inputs = 784
num_outputs = 10

w = torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)
b = torch.zeros(num_outputs,requires_grad=True)
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述

2.3 定义softmax操作

在这里插入图片描述

def softmax(x):
	x_exp = torch.exp(x)
	partition = x_exp.sum(1,keepdim=True)
	return x_exp/partition
  • 1
  • 2
  • 3
  • 4
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/73112
推荐阅读
相关标签
  

闽ICP备14008679号