当前位置:   article > 正文

【图像分类案例】(8) ResNet50 鸟类图像4分类,附Pytorch完整代码_鸟类分类代码

鸟类分类代码

大家好,今天和大家分享一些如何使用 Pytorch 搭建 ResNet50 卷积神经网络模型,并使用迁移学习的思想训练网络,完成鸟类图片的预测。

ResNet 的原理 TensorFlow2 实现方式可以看我之前的两篇博文,这里就不详细说明原理了。

ResNet18、34: https://blog.csdn.net/dgvv4/article/details/122396424

ResNet50: https://blog.csdn.net/dgvv4/article/details/121878494


1. 模型构建

首先导入网络构建过程中所有需要用到的工具包,本小节的所有代码写在 ResNet.py 文件中

  1. import torch
  2. from torch import nn
  3. from torchstat import stat # 查看网络参数
  4. from torchsummary import summary # 查看网络结构

1.1 构建单个残差块

残差单元的结构如下图所示,一种是基本模块,即输入特征图的尺寸和输出特征层的尺寸相同,两个特征图可以直接叠加;一种是下采样模块,即主干部分对输入特征图使用 stride=2 的下采样卷积,使得输入特征图的尺寸变成原来的一半,那么残差边部分也需要对输入特征图进行下采样操作,使得输入特征图经过残差边处理后的 shape 能和主干部分处理后的特征图 shape 相同,从而能够将残差边输出和主干输出直接叠加。

以下图基本残差块为例,先对输入图像使用 1*1 卷积下降通道数,在低维空间下使用 3*3 卷积提取特征,然后再使用 1*1 卷积上升通道数,残差连接输入和输出,将叠加后的结果进过 relu 激活函数。

代码如下:

  1. # -------------------------------------------- #
  2. #(1)残差单元
  3. # x--> 卷积 --> bn --> relu --> 卷积 --> bn --> 输出
  4. # |---------------Identity(短接)----------------|
  5. '''
  6. in_channel 输入特征图的通道数
  7. out_channel 第一次卷积输出特征图的通道数
  8. stride=1 卷积块中3*3卷积的步长
  9. downsample 是否下采样
  10. '''
  11. # -------------------------------------------- #
  12. class Bottleneck(nn.Module):
  13. # 最后一个1*1卷积下降的通道数
  14. expansion = 4
  15. # 初始化
  16. def __init__(self, in_channel, out_channel, stride=1, downsample=None):
  17. # 继承父类初始化方法
  18. super(Bottleneck, self).__init__()
  19. # 属性分配
  20. # 1*1卷积下降通道,padding='same',若stride=1,则[b,in_channel,h,w]==>[b,out_channel,h,w]
  21. self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
  22. kernel_size=1, stride=1, padding=0, bias=False)
  23. # BN层是计算特征图在每个channel上面的均值和方差,需要给出输出通道数
  24. self.bn1 = nn.BatchNorm2d(out_channel)
  25. # relu激活, inplace=True节约内存
  26. self.relu = nn.ReLU(inplace=True)
  27. # 3*3卷积提取特征,[b,out_channel,h,w]==>[b,out_channel,h,w]
  28. self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
  29. kernel_size=3, stride=stride, padding=1, bias=False)
  30. # BN层, 有bn层就不需要bias偏置
  31. self.bn2 = nn.BatchNorm2d(out_channel)
  32. # 1*1卷积上升通道 [b,out_channel,h,w]==>[b,out_channel*expansion,h,w]
  33. self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
  34. kernel_size=1, stride=1, padding=0, bias=False)
  35. # BN层,对out_channel*expansion标准化
  36. self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
  37. # 记录是否需要下采样, 下采样就是第一个卷积层的步长=2,输入和输出的图像的尺寸不一致
  38. self.downsample = downsample
  39. # 前向传播
  40. def forward(self, x):
  41. # 残差边
  42. identity = x
  43. # 如果第一个卷积层stride=2下采样了,那么残差边也需要下采样
  44. if self.downsample is not None:
  45. # 下采样方法
  46. identity = self.downsample(x)
  47. # 主干部分
  48. x = self.conv1(x)
  49. x = self.bn1(x)
  50. x = self.relu(x)
  51. x = self.conv2(x)
  52. x = self.bn2(x)
  53. x = self.relu(x)
  54. x = self.conv3(x)
  55. x = self.bn3(x)
  56. # 残差连接
  57. x = x + identity
  58. # relu激活
  59. x = self.relu(x)
  60. return x # 输出残差单元的结果

1.2 构建网络

我们已经成功构建完单个残差单元的类,而残差结构就是由多个残差单元堆叠而来的,ResnNet50 中有 4 组残差结构,每个残差结构分别堆叠了 3,4,6,3 个残差单元,如下图所示。

第一个残差结构中的第一个残差单元只需要调整输入特征图的通道数,不需要改变特征图的尺寸;而其他三个的残差结构的第一个残差单元不仅需要对输入特征图调整通道数,还要对输入特征图进行下采样操作

代码如下:

  1. # -------------------------------------------- #
  2. #(2)网络构建
  3. '''
  4. block: 残差单元
  5. blocks_num: 每个残差结构使用残差单元的数量
  6. num_classes: 分类数量
  7. include_top: 是否包含分类层(全连接)
  8. '''
  9. # -------------------------------------------- #
  10. class ResNet(nn.Module):
  11. # 初始化
  12. def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
  13. # 继承父类初始化方法
  14. super(ResNet, self).__init__()
  15. # 属性分配
  16. self.include_top = include_top
  17. self.in_channel = 64 # 第一个卷积层的输出通道数
  18. # 7*7卷积下采样层处理输入图像 [b,3,h,w]==>[b,64,h,w]
  19. self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channel,
  20. kernel_size=7, stride=2, padding=3, bias=False)
  21. # BN对每个通道做标准化
  22. self.bn1 = nn.BatchNorm2d(self.in_channel)
  23. # relu激活函数
  24. self.relu = nn.ReLU(inplace=True)
  25. # 3*3最大池化层 [b,64,h,w]==>[b,64,h//2,w//2]
  26. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  27. # 残差卷积块
  28. # 第一个残差结构不需要下采样只需要调整通道
  29. self.layer1 = self._make_layer(block, 64, blocks_num[0])
  30. # 下面的残差结构的第一个残差单元需要进行下采样
  31. self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
  32. self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
  33. self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
  34. # 分类层
  35. if self.include_top:
  36. # 自适应全局平均池化,无论输入特征图的shape是多少,输出特征图的(h,w)==(1,1)
  37. self.avgpool = nn.AdaptiveAvgPool2d((1,1)) # output
  38. # 全连接分类
  39. self.fc = nn.Linear(512*block.expansion, num_classes)
  40. # 卷积层权重初始化
  41. for m in self.modules():
  42. if isinstance(m, nn.Conv2d):
  43. nn.init.kaiming_normal_(m.weight, mode='fan_out')
  44. # 残差结构
  45. '''
  46. block: 代表残差单元
  47. channel: 残差结构中第一个卷积层的输出通道数
  48. block_num: 代表一个残差结构包含多少个残差单元
  49. stride: 是否下采样stride=2
  50. '''
  51. def _make_layer(self, block, channel, block_num, stride=1):
  52. # 是否需要进行下采样
  53. downsample = None
  54. # 如果stride=2或者残差单元的输入和输出通道数不一致
  55. # 就对残差单元的shortcut部分执行下采样操作
  56. if stride != 1 or self.in_channel != channel * block.expansion:
  57. # 残差边需要下采样
  58. downsample = nn.Sequential(
  59. # 对于第一个残差单元的残差边部分只需要调整通道
  60. nn.Conv2d(in_channels=self.in_channel, out_channels=channel*block.expansion,
  61. kernel_size=1, stride=stride, bias=False),
  62. nn.BatchNorm2d(channel*block.expansion))
  63. # 一个残差结构堆叠多个残差单元
  64. layers = []
  65. # 先堆叠第一个残差单元,因为这个需要下采样
  66. layers.append(block(self.in_channel, channel, stride=stride, downsample=downsample))
  67. # 获得第一个残差单元的输出特征图个数, 作为第二个残差单元的输入
  68. self.in_channel = channel * block.expansion
  69. # 堆叠剩下的残差单元,此时的shortcut部分不需要下采样
  70. for _ in range(1, block_num):
  71. layers.append(block(self.in_channel, channel))
  72. # 返回构建好了的残差结构
  73. return nn.Sequential(*layers) # *代表将layers以非关键字参数的形式返还
  74. # 前向传播
  75. def forward(self, x):
  76. # 输入层
  77. x = self.conv1(x)
  78. x = self.bn1(x)
  79. x = self.relu(x)
  80. x = self.maxpool(x)
  81. # 残差结构
  82. x = self.layer1(x)
  83. x = self.layer2(x)
  84. x = self.layer3(x)
  85. x = self.layer4(x)
  86. # 分类层
  87. if self.include_top:
  88. # 全局平均池化
  89. x = self.avgpool(x)
  90. # 打平
  91. x = torch.flatten(x, 1)
  92. # 全连接分类
  93. x = self.fc(x)
  94. return x

1.3 查看网络结构

 [3,4,6,3] 代表四个残差结构中分别堆叠了多少个残差单元,include_top=True 代表包含网络的分类层,默认是1000个分类,即全连接层输出预测结果。网络结构如下图所示:

  1. # 构建resnet50
  2. def resnet50(num_classes=1000, include_top=True):
  3. return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, include_top=include_top)
  4. if __name__ == '__main__':
  5. # 接收网络模型
  6. model = resnet50()
  7. # print(model)
  8. # 查看网络参数量,不需要指定输入特征图像的batch维度
  9. stat(model, input_size=(3,224,224))
  10. # 查看网络结构及参数
  11. summary(model, input_size=[(3,224,224)], device='cpu')

网络的参数量和计算量如下:

  1. ================================================================
  2. Total params: 25,557,032
  3. Trainable params: 25,557,032
  4. Non-trainable params: 0
  5. ----------------------------------------------------------------
  6. Input size (MB): 0.57
  7. Forward/backward pass size (MB): 286.56
  8. Params size (MB): 97.49
  9. Estimated Total Size (MB): 384.62
  10. ----------------------------------------------------------------

2. 网络训练

2.1 文件配置

首先我们需要将接下来所有用到的文件包,文件路径,先写好了方便统一管理。使用迁移学习的方法训练网络。

  1. import torch
  2. from torch import nn, optim
  3. from torchvision import datasets, transforms
  4. from torch.utils.data import DataLoader
  5. from ResNet import resnet50 # 从自定义的ResNet.py文件中导入resnet50这个函数
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. # -------------------------------------------------- #
  9. #(0)参数设置
  10. # -------------------------------------------------- #
  11. batch_size = 32 # 每个step训练32张图片
  12. epochs = 10 # 共训练10次
  13. # -------------------------------------------------- #
  14. #(1)文件配置
  15. # -------------------------------------------------- #
  16. # 数据集文件夹位置
  17. filepath = 'D:/deeplearning/test/数据集/4种鸟分类/new_data/'
  18. # 权重文件位置
  19. weightpath = 'D:/deeplearning/imgnet/pytorchimgnet/pretrained_weights/resnet50.pth'
  20. # 权重保存文件夹路径
  21. savepath = 'D:/deeplearning/imgnet/pytorchimgnet/save_weights/'
  22. # 获取GPU设备
  23. if torch.cuda.is_available(): # 如果有GPU就用,没有就用CPU
  24. device = torch.device('cuda:0')
  25. else:
  26. device = torch.device('cpu')

2.2 构造数据集

首先定义训练集和验证集的数据预处理函数。将输入图像的尺寸变成模型要求的 224*224 大小,然后再将像素值类型从 numpy 变成 tensor 类型,并归一化处理,像素值大小从 [0,255] 变换到 [0,1],再调整输入图像的维度,从 [h,w,c] 变成 [c,h,w];接着对图像的每个颜色通道做标准化处理使像素值满足以0.5为均值,0.5为方差的正态分布

预处理之后就构造训练集和验证集,指定 batch_size=32,代表训练时每个 step 训练32张图片

  1. # -------------------------------------------------- #
  2. #(2)构造数据集
  3. # -------------------------------------------------- #
  4. # 训练集的数据预处理
  5. transform_train = transforms.Compose([
  6. # 数据增强,随机裁剪224*224大小
  7. transforms.RandomResizedCrop(224),
  8. # 数据增强,随机水平翻转
  9. transforms.RandomHorizontalFlip(),
  10. # 数据变成tensor类型,像素值归一化,调整维度[h,w,c]==>[c,h,w]
  11. transforms.ToTensor(),
  12. # 对每个通道的像素进行标准化,给出每个通道的均值和方差
  13. transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))])
  14. # 验证集的数据预处理
  15. transform_val = transforms.Compose([
  16. # 将输入图像大小调整为224*224
  17. transforms.Resize((224,224)),
  18. transforms.ToTensor(),
  19. transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))])
  20. # 读取训练集并预处理
  21. train_dataset = datasets.ImageFolder(root=filepath + 'train', # 训练集图片所在的文件夹
  22. transform = transform_train) # 训练集的预处理方法
  23. # 读取验证集并预处理
  24. val_dataset = datasets.ImageFolder(root=filepath + 'val', # 验证集图片所在的文件夹
  25. transform = transform_val) # 验证集的预处理方法
  26. # 查看训练集和验证集的图片数量
  27. train_num = len(train_dataset)
  28. val_num = len(val_dataset)
  29. print('train_num:', train_num, 'val_num:', val_num) # 453, 112
  30. # 查看图像类别及其对应的索引
  31. class_dict = train_dataset.class_to_idx
  32. print(class_dict) # {'Bananaquit': 0, 'Black Skimmer': 1, 'Black Throated Bushtiti': 2, 'Cockatoo': 3}
  33. # 将类别名称保存在列表中
  34. class_names = list(class_dict.keys())
  35. # 构造训练集
  36. train_loader = DataLoader(dataset=train_dataset, # 接收训练集
  37. batch_size=batch_size, # 训练时每个step处理32张图
  38. shuffle=True, # 打乱每个batch
  39. num_workers=0) # 加载数据时的线程数量,windows环境下只能=0
  40. # 构造验证集
  41. val_loader = DataLoader(dataset=val_dataset,
  42. batch_size=batch_size,
  43. shuffle=False,
  44. num_workers=0)

2.3 数据可视化

接下来查看一下数据集中构造的图片和标签是什么样的,这里要注意的是在预处理时已经对整个数据集做了归一化和以 0.5 为均值 0.5 为方差做的标准化这里需要做一次反标准化,img = img / 2 + 0.5,变会归一化之后的结果不然画出来的图太抽象了。

标准化: img = \frac{img-0.5}{0.5}        反标准化: img = img * 0.5 + 0.5

  1. # -------------------------------------------------- #
  2. #(3)数据可视化
  3. # -------------------------------------------------- #
  4. # 取出一个batch的训练集,返回图片及其标签
  5. train_img, train_label = iter(train_loader).next()
  6. # 查看shape, img=[32,3,224,224], label=[32]
  7. print(train_img.shape, train_label.shape)
  8. # 从一个batch中取出前9张图片
  9. img = train_img[:9] # [9, 3, 224, 224]
  10. # 将图片反标准化,像素变到0-1之间
  11. img = img / 2 + 0.5
  12. # tensor类型变成numpy类型
  13. img = img.numpy()
  14. class_label = train_label.numpy()
  15. # 维度重排 [b,c,h,w]==>[b,h,w,c]
  16. img = np.transpose(img, [0,2,3,1])
  17. # 创建画板
  18. plt.figure()
  19. # 绘制四张图片
  20. for i in range(img.shape[0]):
  21. plt.subplot(3,3,i+1)
  22. plt.imshow(img[i])
  23. plt.xticks([]) # 不显示x轴刻度
  24. plt.yticks([]) # 不显示y轴刻度
  25. plt.title(class_names[class_label[i]]) # 图片对应的类别
  26. plt.tight_layout() # 轻量化布局
  27. plt.show()

绘制前9张图片及其标签:


2.4 模型构建

首先导入我们构建的 resnet50 网络模型,它默认有 1000 个分类,也就是网络的最后一个全连接层的输出结果和我们当前的任务不一样。

通过 torch.load() 将权重文件加载到内存中,再通过 net.load_state_dict() 将网络的每一层权重加载上来。此时的全连接层是1000分类,因此我们将它修改为4分类的,通过 net.fc.in_features 获取最后一个全连接层的输入通道数,然后再重写这个全连接层 net.fc = nn.Linear(in_channel, 4) 将其输出神经元个数改成4个

  1. # -------------------------------------------------- #
  2. #(4)加载模型
  3. # -------------------------------------------------- #
  4. # 1000分类层
  5. net = resnet50(num_classes=1000, include_top=True)
  6. # 加载预训练权重
  7. net.load_state_dict(torch.load(weightpath, map_location=device))
  8. # 为网络重写分类层
  9. in_channel = net.fc.in_features # 2048
  10. net.fc = nn.Linear(in_channel, 4) # [b,2048]==>[b,4]
  11. # 将模型搬运到GPU上
  12. net.to(device)
  13. # 定义交叉熵损失
  14. loss_function = nn.CrossEntropyLoss()
  15. # 定义优化器
  16. optimizer = optim.Adam(net.parameters(), lr=0.002)
  17. # 保存准确率最高的一次迭代
  18. best_acc = 0.0

2.5 训练与验证

这里要注意的就是网络训练和测试的模式不一样,训练时 Dropout 层随机杀死神经元,BN层计算在 batch 维度上计算均值和方差验证时 Dropout 层不起作用,BN 层去整个训练集计算得到的均值和方差。通过 net.train() 和 net.eval()切换训练和验证模式

  1. # -------------------------------------------------- #
  2. #(5)网络训练
  3. # -------------------------------------------------- #
  4. for epoch in range(epochs):
  5. print('-'*30, '\n', 'epoch:', epoch)
  6. # 将模型设置为训练模型, dropout层和BN层只在训练时起作用
  7. net.train()
  8. # 计算训练一个epoch的总损失
  9. running_loss = 0.0
  10. # 每个step训练一个batch
  11. for step, data in enumerate(train_loader):
  12. # data中包含图像及其对应的标签
  13. images, labels = data
  14. # 梯度清零,因为每次计算梯度是一个累加
  15. optimizer.zero_grad()
  16. # 前向传播
  17. outputs = net(images.to(device))
  18. # 计算预测值和真实值的交叉熵损失
  19. loss = loss_function(outputs, labels.to(device))
  20. # 梯度计算
  21. loss.backward()
  22. # 权重更新
  23. optimizer.step()
  24. # 累加每个step的损失
  25. running_loss += loss.item()
  26. # 打印每个step的损失
  27. print(f'step:{step} loss:{loss}')
  28. # -------------------------------------------------- #
  29. #(6)网络验证
  30. # -------------------------------------------------- #
  31. net.eval() # 切换为验证模型,BN和Dropout不起作用
  32. acc = 0.0 # 验证集准确率
  33. with torch.no_grad(): # 下面不进行梯度计算
  34. # 每次验证一个batch
  35. for data_test in val_loader:
  36. # 获取验证集的图片和标签
  37. test_images, test_labels = data_test
  38. # 前向传播
  39. outputs = net(test_images.to(device))
  40. # 预测分数的最大值
  41. predict_y = torch.max(outputs, dim=1)[1]
  42. # 累加每个step的准确率
  43. acc += (predict_y == test_labels.to(device)).sum().item()
  44. # 计算所有图片的平均准确率
  45. acc_test = acc / val_num
  46. # 打印每个epoch的训练损失和验证准确率
  47. print(f'total_train_loss:{running_loss/step}, total_test_acc:{acc_test}')
  48. # -------------------------------------------------- #
  49. #(7)权重保存
  50. # -------------------------------------------------- #
  51. # 保存最好的准确率的权重
  52. if acc_test > best_acc:
  53. # 更新最佳的准确率
  54. best_acc = acc_test
  55. # 保存的权重名称
  56. savename = savepath + 'resnet50.pth'
  57. # 保存当前权重
  58. torch.save(net.state_dict(), savename)

训练过程如下:

  1. ------------------------------
  2. epoch: 9
  3. step:0 loss:0.6036051511764526
  4. step:1 loss:0.9885318279266357
  5. step:2 loss:0.5862273573875427
  6. step:3 loss:0.5116483569145203
  7. step:4 loss:0.5162124633789062
  8. step:5 loss:0.5078244805335999
  9. step:6 loss:0.40748751163482666
  10. step:7 loss:0.477965384721756
  11. step:8 loss:0.5453536510467529
  12. step:9 loss:0.4567239284515381
  13. step:10 loss:0.28334566950798035
  14. step:11 loss:0.4871848225593567
  15. step:12 loss:0.47422197461128235
  16. step:13 loss:0.5443617105484009
  17. step:14 loss:0.5486546754837036
  18. total_train_loss:0.5670963547059468, total_test_acc:0.8660714285714286

3. 预测阶段

3.1 文件配置

首先将工具包和文件路径先写好,便于统一管理

  1. import torch
  2. from torchvision import transforms
  3. from PIL import Image
  4. from ResNet import resnet50
  5. import matplotlib.pyplot as plt
  6. # -------------------------------------------------- #
  7. #(0)参数设置
  8. # -------------------------------------------------- #
  9. # 图片文件路径
  10. img_path = 'D:/deeplearning/imgnet/pytorchimgnet/img/Black_Throated_Bushtiti.jpg'
  11. # 权重参数路径
  12. weights_path = 'D:/deeplearning/imgnet/pytorchimgnet/save_weights/resnet50.pth'
  13. # 预测索引对应的类别名称
  14. class_names=['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
  15. # 获取GPU设备
  16. if torch.cuda.is_available(): # 如果有GPU就用,没有就用CPU
  17. device = torch.device('cuda:0')
  18. else:
  19. device = torch.device('cpu')

3.2 数据预处理和可视化

这里采用和验证集相同的预处理方法,由于只预测一张图像,因此在送入网络之前需要给图像增加 batch 维度,[c,h,w]==>[b,c,h,w]

  1. # -------------------------------------------------- #
  2. #(1)数据加载
  3. # -------------------------------------------------- #
  4. # 预处理函数
  5. data_transform = transforms.Compose([
  6. # 将输入图像的尺寸变成224*224
  7. transforms.Resize((224,224)),
  8. # 数据变成tensor类型,像素值归一化,调整维度[h,w,c]==>[c,h,w]
  9. transforms.ToTensor(),
  10. # 对每个通道的像素进行标准化,给出每个通道的均值和方差
  11. transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))])
  12. # 读取图片
  13. frame = Image.open(img_path)
  14. # 展示图片
  15. plt.imshow(frame)
  16. plt.title('Black_Throated_Bushtiti')
  17. plt.show()
  18. # 数据预处理
  19. img = data_transform(frame)
  20. # 给图像增加batch维度 [c,h,w]==>[b,c,h,w]
  21. img = torch.unsqueeze(img, dim=0)

需要预测的图像及其标签


3.3 图片预测

在网络前向传播之前将模型设置为验证模式 model.eval()只做前向传播的操作,不进行梯度更新操作 with torch.no_grad(): 不计算梯度。

经过前向传播后,图像的shape变成 [b,4],由于这里的 batch_size=1,因此将可以将batch维度挤压掉,得到由4个元素组成的向量,代表图片属于四个类别的分数。

  1. # -------------------------------------------------- #
  2. #(2)图像预测
  3. # -------------------------------------------------- #
  4. # 加载模型
  5. model = resnet50(num_classes=4, include_top=True)
  6. # 加载权重文件
  7. model.load_state_dict(torch.load(weights_path, map_location=device))
  8. # 模型切换成验证模式,dropout和bn切换形式
  9. model.eval()
  10. # 前向传播过程中不计算梯度
  11. with torch.no_grad():
  12. # 前向传播
  13. outputs = model(img)
  14. # 只有一张图就挤压掉batch维度
  15. outputs = torch.squeeze(outputs)
  16. # 计算图片属于4个类别的概率
  17. predict = torch.softmax(outputs, dim=0)
  18. # 得到类别索引
  19. predict_cla = torch.argmax(predict).numpy()
  20. # 获取最大预测类别概率
  21. predict_score = round(torch.max(predict).item(), 4)
  22. # 获取预测类别的名称
  23. predict_name = class_names[predict_cla]
  24. # 展示预测结果
  25. plt.imshow(frame)
  26. plt.title('class: '+str(predict_name)+'\n score: '+str(predict_score))
  27. plt.show()

预测结果如图:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/blog/article/detail/92812
推荐阅读
相关标签
  

闽ICP备14008679号