
Computer Vision: ResNet


1 Introduction to ResNet

1.1 ResNet Overview

ResNet was proposed in 2015 by a Microsoft Research team and took first place that year in the classification, detection, and segmentation competitions. The paper's four authors, Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, are now all well-known names in artificial intelligence; at the time, all four were members of Microsoft Research Asia. Their experiments showed that residual networks are easier to optimize and that increasing depth improves accuracy. On ImageNet they evaluated a 152-layer residual network (8x the depth of VGG nets, yet with lower complexity), and an ensemble of these networks achieved a 3.57% top-5 error rate, winning first place in the ILSVRC 2015 classification task.

Paper address: original link

This is a classic paper in computer vision. As Mu Li once put it, if you are using a convolutional neural network, there is about a fifty percent chance you are using ResNet or one of its variants. The ResNet paper has been cited more than 100,000 times.

1.2 ResNet Network Architecture

The classic ResNet architectures are ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152. ResNet-18 and ResNet-34 share the same basic building block and are relatively shallow networks; the other three are deeper networks, with ResNet-50 being the most commonly used.

Residual networks were proposed to solve the degradation problem that appears when a deep neural network (DNN) has many hidden layers. Degradation means that as more layers are added, accuracy first saturates and then degrades rapidly, and this degradation is not caused by overfitting.

Suppose a network A has training error x. Build a network B by stacking a few extra layers, call them C, on top of A, and let those layers have no effect on A's output (that is, they compute an identity mapping). Then B's training error should also be x; in particular, B's training error should never exceed A's. In practice, however, deeper plain networks do train worse, which shows that getting the added layers C to learn an identity mapping is not a trivial problem for the optimizer.

To solve this problem, the residual module adds a direct shortcut connection between its input and output, carrying the input across unchanged. The stacked layers C then only need to learn the residual, the difference between the desired mapping and the identity, rather than the full mapping; if the identity is already optimal, the residual can simply be pushed toward zero. Because C learns only the residual, the module is called a residual module.
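The idea can be sketched in a few lines of PyTorch. This ResidualBlock is a hypothetical minimal example for illustration, not the Bottleneck block used in the full implementation later:

```python
import torch
import torch.nn.functional as F
from torch import nn

class ResidualBlock(nn.Module):
    """Minimal residual block: output = F(x) + x."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + x)  # identity shortcut: add the input back

x = torch.randn(1, 16, 8, 8)
block = ResidualBlock(16)
out = block(x)
print(out.shape)  # same shape as the input: torch.Size([1, 16, 8, 8])
```

If the optimal mapping is the identity, the optimizer only has to push the convolution weights toward zero, which is far easier than learning an identity mapping through a stack of nonlinear layers.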

In addition, much like GoogLeNet, which appeared at almost the same time, ResNet uses a global average pooling layer just before the classification layer. With these changes, ResNet can learn networks as deep as 152 layers. It achieves higher accuracy than VGGNet and GoogLeNet while being more computationally efficient than VGGNet; ResNet-152 reaches 95.51% top-5 accuracy.
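The global average pooling step collapses each final feature map to a single number, which keeps the classifier head small. A minimal sketch of the shapes involved (the 1x2048x7x7 input is the typical ResNet-50 final feature map for a 224x224 image):

```python
import torch
from torch import nn

# Global average pooling: average each channel's 7x7 feature map to one value.
gap = nn.AdaptiveAvgPool2d((1, 1))
features = torch.randn(1, 2048, 7, 7)
pooled = gap(features)
print(pooled.shape)   # torch.Size([1, 2048, 1, 1])
flat = pooled.reshape(1, -1)
print(flat.shape)     # torch.Size([1, 2048]), fed to the final linear layer
```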

(The original post includes a figure here showing the ResNet-18 and ResNet-50 network architectures.)

2 Implementing ResNet-50 on CIFAR-10 with PyTorch

2.1 The CIFAR-10 Dataset

CIFAR-10 is a computer vision dataset for general object recognition collected by Alex Krizhevsky and Ilya Sutskever, students of Geoffrey Hinton. It contains 60,000 32 x 32 RGB color images across 10 classes, split into 50,000 training images and 10,000 test images.

The 10 classes of RGB color images in CIFAR-10 are: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

CIFAR-10 is a color image dataset that is closer to everyday objects than handwritten digits. Compared with MNIST, it differs in the following ways:

  • CIFAR-10 images are 3-channel RGB color images, while MNIST images are grayscale.
  • CIFAR-10 images are 32 x 32, slightly larger than MNIST's 28 x 28.

Compared with handwritten characters, CIFAR-10 contains real-world objects, which are not only noisy but also vary widely in scale and appearance, making recognition considerably harder. A plain linear model such as softmax regression performs poorly on CIFAR-10.

2.2 Code Implementation

import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, transforms


class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, stride=[1, 1, 1], padding=[0, 1, 0], first=False) -> None:
        super(Bottleneck, self).__init__()
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=padding[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=padding[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=stride[2], padding=padding[2], bias=False),
            nn.BatchNorm2d(out_channels * 4)
        )
        # The shortcut must match the main path's output shape, so handle
        # the two cases separately.
        self.shortcut = nn.Sequential()
        if first:
            self.shortcut = nn.Sequential(
                # A 1x1 convolution adjusts the channel count; the spatial
                # downsampling happens when stride[1] == 2, i.e. whenever the
                # number of output channels is increased.
                nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, stride=stride[1], bias=False),
                nn.BatchNorm2d(out_channels * 4)
            )

    def forward(self, x):
        out = self.bottleneck(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet50(nn.Module):
    def __init__(self, Bottleneck, num_classes=10) -> None:
        super(ResNet50, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),  # activation before pooling, as in the original stem
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.conv2 = self._make_layer(Bottleneck, 64, [[1, 1, 1]] * 3, [[0, 1, 0]] * 3)
        self.conv3 = self._make_layer(Bottleneck, 128, [[1, 2, 1]] + [[1, 1, 1]] * 3, [[0, 1, 0]] * 4)
        self.conv4 = self._make_layer(Bottleneck, 256, [[1, 2, 1]] + [[1, 1, 1]] * 5, [[0, 1, 0]] * 6)
        self.conv5 = self._make_layer(Bottleneck, 512, [[1, 2, 1]] + [[1, 1, 1]] * 2, [[0, 1, 0]] * 3)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def _make_layer(self, block, out_channels, strides, paddings):
        layers = []
        flag = True  # only the first block of each stage needs a projection shortcut
        for i in range(0, len(strides)):
            layers.append(block(self.in_channels, out_channels, strides[i], paddings[i], first=flag))
            flag = False
            self.in_channels = out_channels * 4
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out


def get_format_time():
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


transform = transforms.Compose([
    ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    ),
    transforms.Resize((224, 224))
])

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

testing_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)

if __name__ == "__main__":
    res50 = ResNet50(Bottleneck)
    batch_size = 128
    train_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(dataset=testing_data, batch_size=batch_size, shuffle=True, drop_last=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = res50.to(device)
    cost = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    epochs = 20
    accuracy_rate = []

    for epoch in range(epochs):
        train_loss = 0.0
        train_correct = 0
        model.train()
        print(f"{get_format_time()}, train epoch: {epoch}/{epochs}")
        for step, (images, labels) in enumerate(train_loader, 0):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            optimizer.zero_grad()
            loss = cost(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_correct += torch.sum(predicted == labels.data).item()

        # Validate on the test set.
        model.eval()
        test_correct = 0
        test_total = 0
        test_loss = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = cost(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                test_total += labels.size(0)
                test_correct += torch.sum(predicted == labels.data).item()
                test_loss += loss.item()
        accuracy = 100 * test_correct / test_total
        accuracy_rate.append(accuracy)

        print("{}, Train Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Loss is:{:.4f}, Test Accuracy is:{:.4f}%".format(
            get_format_time(),
            train_loss / len(training_data),
            100 * train_correct / len(training_data),
            test_loss / len(testing_data),
            100 * test_correct / len(testing_data)
        ))

    accuracy_rate = np.array(accuracy_rate)
    times = np.linspace(1, epochs, epochs)
    plt.xlabel('epoch')
    plt.ylabel('accuracy rate')
    plt.plot(times, accuracy_rate)
    plt.show()
    print(f"{get_format_time()},accuracy_rate={accuracy_rate}")

2.3 Environment Setup

(1) If running on a CPU, prepare the environment as follows:

  conda create -n cv python=3.9
  conda activate cv
  pip install torchvision==0.9.0
  pip install numpy
  pip install matplotlib
  pip install requests

(2) If running on a GPU, prepare the environment as follows:

Use the nvidia-smi command to check which CUDA version the driver supports:

  Tue May 23 15:24:10 2023
  +-----------------------------------------------------------------------------+
  | NVIDIA-SMI 528.89       Driver Version: 528.89       CUDA Version: 12.0     |
  |-------------------------------+----------------------+----------------------+
  | GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
  | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
  |                               |                      |               MIG M. |
  |===============================+======================+======================|
  |   0  Tesla T4            TCC  | 00000000:01:00.0 Off |                    0 |
  | N/A   55C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
  |                               |                      |                  N/A |
  +-------------------------------+----------------------+----------------------+
  +-----------------------------------------------------------------------------+
  | Processes:                                                                  |
  |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
  |        ID   ID                                                   Usage      |
  |=============================================================================|
  |  No running processes found                                                 |
  +-----------------------------------------------------------------------------+

Create the environment, then install the matching GPU build from the PyTorch wheel index:

  conda create -n cv python=3.9
  conda activate cv
  pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
  pip install numpy
  pip install matplotlib
  pip install requests
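After installation, a quick check from Python confirms the CUDA build is usable before starting a long training run:

```python
import torch

# Report the installed torch version and whether a CUDA device is visible.
print(torch.__version__)
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```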

Running nvidia-smi again during training confirms the job is on the GPU:

  Tue May 23 15:25:25 2023
  +-----------------------------------------------------------------------------+
  | NVIDIA-SMI 528.89       Driver Version: 528.89       CUDA Version: 12.0     |
  |-------------------------------+----------------------+----------------------+
  | GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
  | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
  |                               |                      |               MIG M. |
  |===============================+======================+======================|
  |   0  Tesla T4            TCC  | 00000000:01:00.0 Off |                    0 |
  | N/A   56C    P0    28W /  70W |   1101MiB / 15360MiB |      3%      Default |
  |                               |                      |                  N/A |
  +-------------------------------+----------------------+----------------------+
  +-----------------------------------------------------------------------------+
  | Processes:                                                                  |
  |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
  |        ID   ID                                                   Usage      |
  |=============================================================================|
  |    0   N/A  N/A      6728     C    ...nda\envs\voice\python.exe     1100MiB |
  +-----------------------------------------------------------------------------+

2.4 Results

  2023-12-22 14:44:39, train epoch: 0/20
  2023-12-22 14:46:21, Train Loss is:0.0126, Train Accuracy is:40.9520%, Test Loss is:0.0116, Test Accuracy is:46.3200%
  2023-12-22 14:46:21, train epoch: 1/20
  2023-12-22 14:48:01, Train Loss is:0.0087, Train Accuracy is:59.5060%, Test Loss is:0.0109, Test Accuracy is:51.6700%
  2023-12-22 14:48:01, train epoch: 2/20
  2023-12-22 14:49:40, Train Loss is:0.0070, Train Accuracy is:68.1060%, Test Loss is:0.0072, Test Accuracy is:67.8100%
  2023-12-22 14:49:40, train epoch: 3/20
  2023-12-22 14:51:20, Train Loss is:0.0057, Train Accuracy is:74.2540%, Test Loss is:0.0073, Test Accuracy is:67.7400%
  2023-12-22 14:51:20, train epoch: 4/20
  2023-12-22 14:53:00, Train Loss is:0.0049, Train Accuracy is:77.9280%, Test Loss is:0.0061, Test Accuracy is:73.7400%
  2023-12-22 14:53:00, train epoch: 5/20
  2023-12-22 14:54:41, Train Loss is:0.0042, Train Accuracy is:81.3260%, Test Loss is:0.0049, Test Accuracy is:77.9900%
  2023-12-22 14:54:41, train epoch: 6/20
  2023-12-22 14:56:20, Train Loss is:0.0036, Train Accuracy is:83.9240%, Test Loss is:0.0047, Test Accuracy is:79.0400%
  2023-12-22 14:56:20, train epoch: 7/20
  2023-12-22 14:58:00, Train Loss is:0.0031, Train Accuracy is:86.0780%, Test Loss is:0.0059, Test Accuracy is:75.6300%
  2023-12-22 14:58:00, train epoch: 8/20
  2023-12-22 14:59:39, Train Loss is:0.0027, Train Accuracy is:87.7120%, Test Loss is:0.0048, Test Accuracy is:79.7600%
  2023-12-22 14:59:39, train epoch: 9/20
  2023-12-22 15:01:19, Train Loss is:0.0023, Train Accuracy is:89.3680%, Test Loss is:0.0048, Test Accuracy is:80.5800%
  2023-12-22 15:01:19, train epoch: 10/20
  2023-12-22 15:02:58, Train Loss is:0.0019, Train Accuracy is:91.2760%, Test Loss is:0.0044, Test Accuracy is:82.3400%
  2023-12-22 15:02:58, train epoch: 11/20
  2023-12-22 15:04:38, Train Loss is:0.0016, Train Accuracy is:92.4040%, Test Loss is:0.0045, Test Accuracy is:82.6400%
  2023-12-22 15:04:38, train epoch: 12/20
  2023-12-22 15:06:18, Train Loss is:0.0014, Train Accuracy is:93.7200%, Test Loss is:0.0053, Test Accuracy is:81.7900%
  2023-12-22 15:06:18, train epoch: 13/20
  2023-12-22 15:07:57, Train Loss is:0.0011, Train Accuracy is:94.7360%, Test Loss is:0.0051, Test Accuracy is:81.7700%
  2023-12-22 15:07:57, train epoch: 14/20
  2023-12-22 15:09:37, Train Loss is:0.0010, Train Accuracy is:95.1120%, Test Loss is:0.0062, Test Accuracy is:80.6500%
  2023-12-22 15:09:37, train epoch: 15/20
  2023-12-22 15:11:15, Train Loss is:0.0008, Train Accuracy is:96.1600%, Test Loss is:0.0056, Test Accuracy is:82.0300%
  2023-12-22 15:11:15, train epoch: 16/20
  2023-12-22 15:12:54, Train Loss is:0.0007, Train Accuracy is:96.6140%, Test Loss is:0.0055, Test Accuracy is:82.4200%
  2023-12-22 15:12:54, train epoch: 17/20
  2023-12-22 15:14:34, Train Loss is:0.0007, Train Accuracy is:96.8880%, Test Loss is:0.0068, Test Accuracy is:81.1300%
  2023-12-22 15:14:34, train epoch: 18/20
  2023-12-22 15:16:13, Train Loss is:0.0006, Train Accuracy is:97.0620%, Test Loss is:0.0062, Test Accuracy is:82.1900%
  2023-12-22 15:16:13, train epoch: 19/20
  2023-12-22 15:17:52, Train Loss is:0.0006, Train Accuracy is:97.4180%, Test Loss is:0.0063, Test Accuracy is:82.7800%
  2023-12-22 15:17:53,accuracy_rate=[46.39423  51.752804 67.91867  67.84856  73.85818  78.11498  79.166664
   75.751205 79.887825 80.70914  82.471954 82.77244  81.921074 81.90104
   80.77925  82.16146  82.552086 81.26002  82.32172  82.91266 ]
