当前位置:   article > 正文

手把手教你利用PyTorch实现图像识别_如何利用pytorch实现图像识别

如何利用pytorch实现图像识别

本篇文章完整代码可以在我的公众号【拇指笔记】后台回复"softmax_py"获取
文末有二维码~

识别效果:

1.softmax回归

这一部分分为softmax回归模型的概念、图像分类数据集的概念、softmax回归模型的实现和softmax回归模型基于pytorch框架的实现四部分。

对于离散值预测问题,我们可以使用诸如softmax回归这样的分类模型。softmax回归模型有多个输出单元。本章以softmax回归模型为例,介绍神经网络中的分类模型。

1.1分类问题

例如一个简单的图像分类问题,输入图形高和宽都为2像素,且色彩为灰度(灰度图像的像素值可以用一个标量来表示)。我们将图像的四个像素值记为x1,x2,x3,x4。假设训练数据集中图像的真实标签为狗 猫和鸡,这些标签分别对应着离散值y1,y2,y3。

我们通常使用离散值来表示类别,例如y1=1,y2=2,y3=3。一张图像的标签为1、2和3的数值中的一个,对于这种问题,我们一般使用更加适合离散输出的模型来解决分类问题。

1.2softmax回归模型

softmax回归模型一样将输入特征与权重做线性叠加。于线性回归的主要区别为softmax回归的输出值个数等于标签里的类别数。

在上面的例子中,每个图像又四个像素,对应着每个图象有四个特征值(x),有三种可能的动物类别,对应着三

个离散值标签(o)。所以包含12个权重(w)和3个偏差(b)
o 1 = w 11 x 1 + w 21 x 2 + w 31 x 3 + w 41 x 4 + b 1 , o 2 = w 12 x 1 + w 22 x 2 + w 32 x 3 + w 42 x 4 + b 2 , o 3 = w 13 x 1 + w 23 x 2 + w 33 x 3 + w 43 x 4 + b 3 , w 下 标 命 名 规 则 : 不 同 列 代 表 不 同 输 出 类 型 , 不 同 行 代 表 不 同 像 素 点 。 列 数 代 表 真 实 输 出 的 类 别 数 ; 行 数 代 表 特 征 数 。 o_1=w_{11}x_1+w_{21}x_2+w_{31}x_3+w_{41}x_4+b_1, \\o_2=w_{12}x_1+w_{22}x_2+w_{32}x_3+w_{42}x_4+b_2, \\o_3=w_{13}x_1+w_{23}x_2+w_{33}x_3+w_{43}x_4+b_3, \\w下标命名规则: \\不同列代表不同输出类型,不同行代表不同像素点。 \\列数代表真实输出的类别数;行数代表特征数。 o1=w11x1+w21x2+w31x3+w41x4+b1,o2=w12x1+w22x2+w32x3+w42x4+b2,o3=w13x1+w23x2+w33x3+w43x4+b3,w
softmax回归也是一个单层神经网络,每个输出o的计算都要依赖所有的输入x,所以softmax回归的输出层也是一个全连接层。

在这里插入图片描述

通常将输出值 oi 作为预测类别 i 的置信度,并将值最大的输出所对应的类作为预测输出
a r g i m a x o i arg_imaxo_i argimaxoi
例如o1,o2,o3分别为0.1,10,0.1由于o2最大,那么预测类别为2。

但这种方法有两个问题

  1. 输出层的输出值的范围不确定,难以只管判断这些值的意义

    如:三个值为0.1,10,0.1时,10代表很置信;但当三个值为1000,10,1000时,10又代表不置信。

  2. 由于真实标签也是离散值,这些离散值于不确定范围的输出值之间的误差难以衡量。

softmax运算符解决了以上两个问题。它通过下式将输出值转化为值为正且和为1的概率分布。
y 1 ^ , y 2 ^ , y 3 ^ = s o f t m a x ( o 1 , o 2 , o 3 ) \hat{y_1},\hat{y_2},\hat{y_3}=softmax(o_1,o_2,o_3) y1^,y2^,y3^=softmax(o1,o2,o3)
其中
y 1 ^ = e x p ( 0 1 ) ∑ i = 1 3 e x p ( x i ) ,    y 2 ^ = e x p ( 0 2 ) ∑ i = 1 3 e x p ( x i ) ,    y 3 ^ = e x p ( 0 3 ) ∑ i = 1 3 e x p ( x i ) \hat{y_1}=\frac{exp(0_1)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_2}=\frac{exp(0_2)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_3}=\frac{exp(0_3)}{\sum_{i=1}^3exp(xi)} y1^=i=13exp(xi)exp(01),  y2^=i=13exp(xi)exp(02),  y3^=i=13exp(xi)exp(03)
非常容易看出
y 1 ^ + y 2 ^ + y 3 ^ = 1 且 0 ≤ y 1 ^ , y 2 ^ , y 3 ^ ≤ 1 \hat{y_1}+\hat{y_2}+\hat{y_3}=1 \\且0\leq\hat{y_1},\hat{y_2},\hat{y_3}\leq1 y1^+y2^+y3^=10y1^,y2^,y3^1
基于上两式可知,y1,y2,y3是合法的概率分布。例如:y2=0.8那么不管y1,y3是多少,我们都知道为第二个类别的概率为80%

由于
a r g i m a x o i = a r g i m a x y i ^ arg_imaxo_i = arg_imax\hat{y_i} argimaxoi=argimaxyi^
可以知道,softmax运算不改变预测类别输出。

1.3单样本分类的矢量计算表达式

为了提高运算效率,采用矢量计算。以上面的图像分类问题为例权重和偏差参数的矢量表达式为
W = { w 11   w 12   w 13 w 21   w 22   w 23 w 31   w 32   w 33 w 41   w 42   w 43 } ,    b = [ b 1   b 2   b 3 ] W = \left\{

w11 w12 w13w21 w22 w23w31 w32 w33w41 w42 w43
\right\} ,\ \ b=[b_1 \ b_2\ b_3] W=w11 w12 w13w21 w22 w23w31 w32 w33w41 w42 w43,  b=[b1 b2 b3]
设高和宽分别为2个像素的图像样本 i 的特征为
x ( i ) = [ x 1 ( i )   x 2 ( i )   x 3 ( i )   x 4 ( i ) ] x^{(i)}=[x^{(i)}_1 \ x^{(i)}_2 \ x^{(i)}_3 \ x^{(i)}_4] x(i)=[x1(i) x2(i) x3(i) x4(i)]
输出层输出为
o i = [ o 1 i   o 2 i   o 3 i ] o^{i} = [o_1^{i} \ o_2^{i} \ o_3^{i}] oi=[o1i o2i o3i]
预测的概率分布为
y ^ ( i ) = [ y ^ 1 ( i )   y ^ 2 ( i )   y ^ 3 ( i ) ] \hat{y}^{(i)}=[\hat{y}^{(i)}_1 \ \hat{y}^{(i)}_2 \ \hat{y}^{(i)}_3] y^(i)=[y^1(i) y^2(i) y^3(i)]
最终得到softmax回归对样本 i 分类的矢量计算表达式为
o ( i ) = x ( i ) W + b y ^ ( i ) = s o f t m a x ( o ( i ) ) o^{(i)}=x^{(i)}W+b \\ \hat{y}^{(i)}=softmax(o^{(i)}) o(i)=x(i)W+by^(i)=softmax(o(i))
对于给定的小批量样本,存在
O = X W + b Y ^ = s o f t m a x ( O ) O = XW+b \\\hat{Y}=softmax(O) O=XW+bY^=softmax(O)

1.4交叉熵损失函数

使用softmax运算后可以更方便地于离散标签计算误差。真实标签同样可以变换为一个合法的概率分布,即:对于一个样本(一个图像),它的真实类别为y_i,我们就令y_i为1,其余为0。如图像为猫(第二个),则它的y = [0 1 0 ]。这样就可以使\hat{y}更接近y。

在图像分类问题中,想要预测结果正确并不需要让预测概率与标签概率相等(不同动作 颜色的猫),我们只需要让真实类别对应的概率大于其他类别的概率即可,因此不必使用线性回归模型中的平方损失函数。

我们使用交叉熵函数来计算损失。
H ( y ( i ) , y ^ ( i ) ) = − ∑ j = 1 q y j ( i ) l o g   y ^ j ( i ) H(y^{(i)},\hat{y}^{(i)})=-\sum_{j=1}^q y_j^{(i)}log\ \hat{y}^{(i)}_j H(y(i),y^(i))=j=1qyj(i)log y^j(i)
这个式子中,y^(i) _j 是真实标签概率中的为1的那个元素,而 \hat{y}^{(i)}_j 是预测得到的类别概率中与之对应的那个元素。

由于在y(i)中只有一个标签,因此在y{i}中,除了y^(i) _j 外,其余元素都为0,于是得到上式的简化方程
H ( y ( i ) , y ^ ( i ) ) = − l o g   y ^ j ( i ) H(y^{(i)},\hat{y}^{(i)}) =- log\ \hat{y}^{(i)}_j H(y(i),y^(i))=log y^j(i)
也就是说交叉熵函数只与预测到的概率数有关,只要预测得到的值够大,就可以确保分类结果的正确性。

对于整体样本而言,交叉熵损失函数定义为
l ( θ ) = 1 n ∑ i = 1 n H ( y ( i ) , y ^ ( i ) ) l(\theta) =\frac{1}{n} \sum_{i=1}^n H(y^{(i)},\hat{y}^{(i)}) l(θ)=n1i=1nH(y(i),y^(i))
其中\theta代表模型参数,如果每个样本都只有一个标签,则上式可以简化为
l ( θ ) = − 1 n ∑ i = 1 n l o g   y ^ j ( i ) l(\theta) =-\frac{1}{n} \sum_{i=1}^nlog\ \hat{y}^{(i)}_j l(θ)=n1i=1nlog y^j(i)
最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率 。

2.图像分类数据集(Fashion-MNIST

这一章节需要用到torchvision包,为此,我重装了

这个数据集是我们在后面学习中将会用到的图形分类数据集。它的图像内容相较于手写数字识别数据集MINIST更为复杂一些,更加便于我们直观的观察算法之间的差异。

这一节主要使用torchvision包,主要用来构建计算机视觉模型。

torchvision包的主要构成功能
torchvision.datasets一些加载数据的函数及常用数据集接口
torchvision.madels包含常用的模型结构(含预训练模型)
torchvision.transforms常用的图片变换(裁剪、旋转)
torchvision.utils其他方法

2.1获取数据集

首先导入需要的包

import torch 
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")	
#调用库时,sys.path会自动搜索路径,为了导入d2l这个库,所以需要添加".."
#import d2lzh_pytorch as d2l	这个库找不到不用了
from IPython import display
#在这一节d2l库仅仅在绘图时被使用,因此使用这个库做替代
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

**通过调用torchvision中的torchvision.datasets来下载这个数据集。**第一次调用从网上自动获取数据。

通过设置参数train来制定获取训练数据集或测试数据集(测试集:用来评估模型表现,并不用来训练模型)。

通过设置参数transfrom = transforms.ToTensor()将所有数据转换成Tensor,如果不进行转换则返回PIL图片。

transforms.ToTensor()函数将尺寸为(H*W*C)且数据位于[0,255]之间的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C*H*W)且数据类型为torch.float32且位于[0,0,1.0]的Tensor

C代表通道数,灰度图像的通道数为1

PIL图片是python处理图片的标准

注意:transforms.ToTensor()函数默认将输入类型设置为uint8

#获取训练集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
#获取测试集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download = True,transform = transforms.ToTensor())
  • 1
  • 2
  • 3
  • 4

其中mnist_train和mnist_test可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。

训练集和测试集都有10个类别,训练集中每个类别的图像数为6000,测试集中每个类别的图像数为1000,即:训练集中有60000个样本,测试集中有10000个样本。

len(mnist_train)	#输出训练集的样本数
mnist_train[0]		#通过下标访问任意一个样本,返回值为两个torch,一个特征tensor和一个标签tensor
  • 1
  • 2

Fashion-MNIST数据集中共有十个类别,分别为: t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴) 。

需要将这些文本标签和数值标签相互转换,可以通过以下函数进行。

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
	#labels是一个列表
	#数值标签转文本标签
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

下面是一个可以在意行里画出多张图像和对应标签的函数

def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
   	#绘制矢量图
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    #创建子图,一行len(images)列,图片大小12*12
    for f, img, lbl in zip(figs, images, labels):
        #zip函数将他们压缩成由多个元组组成的列表
        f.imshow(img.view((28, 28)).numpy())
        #将img转形为28*28大小的张量,然后转换成numpy数组
        f.set_title(lbl)
        #设置每个子图的标题为标签
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
        #关闭x轴y轴
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

上述函数的使用

X,y = [],[]
#初始化两个列表
for i in range(10):
	X.append(mnist_train[i][0])
	#循环向X列表添加图像
	y.append(mnist_train[i][1])
	#循环向y列表添加标签
show_fashion_mnist(X,get_fashion_mnist_labels(y))
#显示图像和列表
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2.2读取小批量

有了线性回归中读取小批量的经验,我们知道读取小批量可以使用torch中内置的dataloader函数来实现。

dataloader还支持多线程读取数据,通过设置它的num_workers参数。

batch_size = 256
#小批量数目
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
#num_workers=0,不开启多线程读取。
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
  • 1
  • 2
  • 3
  • 4
  • 5

3. 使用pytorch实现softmax回归模型

使用pytorch可以更加便利的实现softmax回归模型。

3.1 获取和读取数据

读取小批量数据的方法:

  1. 首先是获取数据,pytorch可以通过以下代码很方便的获取Fashion-MNIST数据集。

    mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())
    
    mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())
    
    #参数
    
    #root : processed/training.pt 和 processed/test.pt 的主目录 
    #train : True = 训练集, False = 测试集
    #download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下
    #transform = transforms.ToTensor():使所有数据转换为Tensor
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  2. 然后是生成一个迭代器,用来读取数据

    #生成迭代器
    train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle = True,num_workers = 0)
    
    test_iter = torch.utils.data.DataLoader(mnist_test,batch_size = batch_size,shuffle=False,num_workers=0)
    #参数
    
    #dataset:Dataset类型,从其中加载数据
    #batch_size:int类型,每个批量加载多少个数
    #shuffle:bool类型,每个学习周期都打乱顺序
    #num_workers:int类型,加载数据时使用多少子进程。默认值为0.
    #collate_fn:定义如何取样本,可通过定义自己的函数来实现。
    #pin_memory:锁页内存处理。
    #drop_last:bool类型,如果有剩余的样本,True表示丢弃;Flase表示不丢弃
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

3.2 定义和初始化模型

由softmax回归模型的定义可知,softmax回归模型只有权重参数和偏差参数。因此可以使用神经网络子模块中的线性模块。
o 1 = w 11 x 1 + w 21 x 2 + w 31 x 3 + w 41 x 4 + b 1 , o 2 = w 12 x 1 + w 22 x 2 + w 32 x 3 + w 42 x 4 + b 2 , o 3 = w 13 x 1 + w 23 x 2 + w 33 x 3 + w 43 x 4 + b 3 , o_1=w_{11}x_1+w_{21}x_2+w_{31}x_3+w_{41}x_4+b_1, \\o_2=w_{12}x_1+w_{22}x_2+w_{32}x_3+w_{42}x_4+b_2, \\o_3=w_{13}x_1+w_{23}x_2+w_{33}x_3+w_{43}x_4+b_3, o1=w11x1+w21x2+w31x3+w41x4+b1,o2=w12x1+w22x2+w32x3+w42x4+b2,o3=w13x1+w23x2+w33x3+w43x4+b3,

  1. 首先定义网络,softmax回归是一个两层的网络,所以只需要定义输入层和输出层即可。
num_inputs = 784
num_outputs = 10

class LinearNet(nn.Module):
    def __init__(self,num_inputs,num_outputs):
        super(LinearNet,self).__init__()
        self.linear = nn.Linear(num_inputs,num_outputs)
        #定义一个输入层
        
    #定义向前传播(在这个两层网络中,它也是输出层)
    def forward(self,x):
        y = self.linear(x.view(x.shape[0],-1))
        #将x换形为y后,再继续向前传播
        return y
    
net = LinearNet(num_inputs,num_outputs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  1. 初始化参数

使用torch.nn中的init可以快速的初始化参数。我们令权重参数为均值为0,标准差为0.01的正态分布。偏差为0。

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0) 
  • 1
  • 2

3.3 softmax运算和交叉熵损失函数

分开定义softmax运算和交叉熵损失函数会造成数值不稳定。因此PyTorch提供了一个具有良好数值稳定性且包括softmax运算和交叉熵计算的函数。

loss = nn.CrossEntropyLoss()
  • 1

3.4 定义优化算法

依然使用小批量随机梯度下降作为优化算法。定义学习率为0.1。

optimizer = torch.optim.SGD(net.parameters(),lr=0.01)
  • 1

3.5 计算分类准确率

计算准确率的原理:

我们把预测概率最大的类别作为输出类别,如果它与真实类别y一致,说明预测正确。分类准确率就是正确预测数量与总预测数量之比

首先我们需要得到预测的结果。

从一组预测概率(变量y_hat)中找出最大的概率对应的索引(索引即代表了类别)

#argmax(f(x))函数,对f(x)求最大值所对应的点x。我们令f(x)= dim=1,即可实现求所有行上的最大值对应的索引。
A = y_hat.argmax(dim=1)	
#最终输出结果为一个行数与y_hat相同的列向量
  • 1
  • 2
  • 3

然后我们需要将得到的最大概率对应的类别与真实类别(y)比较,判断预测是否是正确的

B = (y_hat.argmax(dim=1)==y).float()
#由于y_hat.argmax(dim=1)==y得到的是ByteTensor型数据,所以我们通过.float()将其转换为浮点型Tensor()
  • 1
  • 2

最后我们需要计算分类准确率

我们知道y_hat的行数就对应着样本总数,所以,对B求平均值得到的就是分类准确率

(y_hat.argmax(dim=1)==y).float().mean()
  • 1

上一步最终得到的数据为tensor(x)的形式,为了得到最终的pytorch number,需要对其进行下一步操作

(y_hat.argmax(dim=1)==y).float().mean().item()
#pytorch number的获取统一通过.item()实现
  • 1
  • 2

整理一下,得到计算分类准确率函数

def accuracy(y_hat,y):
    return (y_hat.argmax(dim=1).float().mean().item())
  • 1
  • 2

作为推广,该函数还可以评价模型net在数据集data_iter上的准确率。

def net_accurary(data_iter,net):
    right_sum,n = 0.0,0
    for X,y in data_iter:
    #从迭代器data_iter中获取X和y
        right_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        #计算准确判断的数量
        n +=y.shape[0]
        #通过shape[0]获取y的零维度(列)的元素数量
    return right_sum/n
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

3.6 训练模型

num_epochs = 5
#一共进行五个学习周期

def train_softmax(net,train_iter,test_iter,loss,num_epochs,batch_size,optimizer,net_accurary):
    for epoch in range(num_epochs):
        #损失值、正确数量、总数 初始化。
        train_l_sum,train_right_sum,n= 0.0,0.0,0
        
        for X,y in train_iter:
            y_hat = net(X)
            l = loss(y_hat,y).sum()
            #数据集损失函数的值=每个样本的损失函数值的和。            
            optimizer.zero_grad()			#对优化函数梯度清零
            l.backward()	#对损失函数求梯度
            optimizer(params,lr,batch_size)
            
            train_l_sum += l.item()
            train_right_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            
        test_acc = net_accurary(test_iter, net)	#测试集的准确率
        print('第%d学习周期, 误差%.4f, 训练准确率%.3f, 测试准确率%.3f' % (epoch + 1, train_l_sum / n, train_right_sum / n, test_acc))
        
train_softmax(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,optimizernet_accurary,net_accurary)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

训练效果

3.7 图像分类

使用训练好的模型对测试集进行预测

做一个模型的最终目的当然不是训练了,所以来预测一下试试。

#将样本的类别数字转换成文本
def get_Fashion_MNIST_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
    #labels是一个列表,所以有了for循环获取这个列表对应的文本列表

#显示图像
def show_fashion_mnist(images,labels):
    display.set_matplotlib_formats('svg')
    #绘制矢量图
    _,figs = plt.subplots(1,len(images),figsize=(12,12))
    #设置添加子图的数量、大小
    for f,img,lbl in zip(figs,images,labels):
        f.imshow(img.view(28,28).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

#从测试集中获得样本和标签
X, y = iter(test_iter).next()

true_labels = get_Fashion_MNIST_labels(y.numpy())
pred_labels = get_Fashion_MNIST_labels(net(X).argmax(dim=1).numpy())

#将真实标签和预测得到的标签加入到图像上
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

show_fashion_mnist(X[0:9], titles[0:9])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

实现效果

第一行是真实标签,第二行是识别标签


写文章不易,如果觉得有用,麻烦关注我呗~
欢迎各位关注【拇指笔记】,每天更新我的学习笔记~
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/262333?site
推荐阅读
相关标签
  

闽ICP备14008679号