
Classification Models -- ResNet Series -- ResNet50

What is ResNet, and what problem does it solve?

Residual network: the output of an earlier layer is carried directly past several intermediate layers and added to the input of a later layer. This means that part of the content of a later feature layer is contributed linearly by an earlier one.

ResNet addresses the degradation problem of deep networks: with the ResNet structure, performance keeps improving as the network gets deeper. The first figure in the original post (from the ResNet paper) shows that a plain 56-layer network has higher training *and* test error than a 20-layer one; this is degradation, not overfitting. The second figure compares the training results of 34-layer and 18-layer networks with and without the ResNet structure.

ResNet: Principle and Structure

Suppose the mapping we want a network block to learn is H(x). Learning H(x) directly is hard; learning the residual function F(x) = H(x) - x instead is much easier, because when the optimal mapping is close to the identity, the block's training target is simply to push F(x) toward 0 rather than to fit some specific mapping. The final mapping is then obtained by adding F(x) and x back together: H(x) = F(x) + x, as shown in the figure.

The output y of this network block is therefore

$y = F(x, \{W_i\}) + x.$

For the addition to be valid, the two terms must have the same dimensions. When they do not, the identity shortcut is replaced by a linear projection $W_s$ that matches dimensions, giving the general form

$y = F(x, \{W_i\}) + W_s\,x.$

The paper mentions two ways to match dimensions: (A) zero-padding to increase the dimensions; (B) a 1×1 convolution to increase the dimensions. A minimal sketch of option (B) follows.
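Here is a minimal sketch of a residual unit that uses the identity shortcut when shapes match and a 1×1 convolution projection (option B) otherwise. The ResidualUnit class and its channel arguments are illustrative, not from the original post:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualUnit(nn.Module):
    """y = F(x) + W_s x, with W_s = identity or a 1x1 conv projection."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # F(x): two 3x3 conv layers with batch norm
        self.f = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
        )
        # W_s: identity when the channel counts already match, 1x1 conv otherwise
        self.shortcut = (nn.Identity() if in_ch == out_ch
                         else nn.Conv2d(in_ch, out_ch, kernel_size=1))

    def forward(self, x):
        return F.relu(self.f(x) + self.shortcut(x))


# quick check: the projection shortcut lets 64 -> 128 channels still be added
print(ResidualUnit(64, 128)(torch.randn(1, 64, 32, 32)).shape)  # [1, 128, 32, 32]
```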

ResNet is built from two kinds of basic blocks, BasicBlock and Bottleneck. The former is used in the shallower networks (up through ResNet-34), the latter in ResNet-50 and deeper.
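Schematically, the two convolution stacks look like this (a rough sketch with the shortcuts omitted; the stage-1 channel counts shown are illustrative):

```python
import torch.nn as nn

# BasicBlock stack (ResNet-18/34): two 3x3 convolutions
basic_stack = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64),
)

# Bottleneck stack (ResNet-50/101/152): 1x1 reduce -> 3x3 -> 1x1 expand (4x)
bottleneck_stack = nn.Sequential(
    nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
    nn.Conv2d(64, 256, 1), nn.BatchNorm2d(256),
)
```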

ResNet50

ResNet50 contains two kinds of blocks, the Conv Block and the Identity Block. The Conv Block's input and output dimensions differ, so Conv Blocks cannot be chained back-to-back; they are used to change the network's dimensions. The Identity Block's input and output dimensions are the same, so it can be chained to deepen the network (a shape check after the figures below illustrates this).

Conv Block

[Figure: Conv Block structure]

Identity Block

[Figure: Identity Block structure]

The overall architecture is as follows:

[Figure: overall ResNet50 architecture]
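To make the Conv Block / Identity Block distinction concrete, here is a quick shape check (a sketch using the ResNet50DownBlock and ResNet50BasicBlock classes defined in resnet50.py below, which implement the Conv Block and the Identity Block respectively):

```python
import torch
from resnet50 import ResNet50DownBlock, ResNet50BasicBlock  # defined below

x = torch.randn(1, 64, 56, 56)
conv_block = ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1],
                               stride=[1, 1, 1, 1], padding=[0, 1, 0])
y = conv_block(x)
print(y.shape)  # torch.Size([1, 256, 56, 56]) -- channel count changed

identity_block = ResNet50BasicBlock(256, outs=[64, 64, 256], kernel_size=[1, 3, 1],
                                    stride=[1, 1, 1, 1], padding=[0, 1, 0])
print(identity_block(y).shape)  # torch.Size([1, 256, 56, 56]) -- shape preserved
```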

Code

resnet50.py

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class ResNet50BasicBlock(nn.Module):
    """Identity Block: 1x1 -> 3x3 -> 1x1 bottleneck with an identity shortcut.
    Input and output shapes match, so these blocks can be chained."""

    def __init__(self, in_channel, outs, kernel_size, stride, padding):
        super(ResNet50BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernel_size[0], stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernel_size[1], stride=stride[0], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernel_size[2], stride=stride[0], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(outs[2])

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.bn3(out)
        # identity shortcut: add the block input directly
        return F.relu(out + x)


class ResNet50DownBlock(nn.Module):
    """Conv Block: the same bottleneck stack, but the shortcut is a
    1x1 conv + BN (the projection W_s), so output dimensions may differ."""

    def __init__(self, in_channel, outs, kernel_size, stride, padding):
        super(ResNet50DownBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernel_size[0], stride=stride[0], padding=padding[0])
        self.bn1 = nn.BatchNorm2d(outs[0])
        self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernel_size[1], stride=stride[1], padding=padding[1])
        self.bn2 = nn.BatchNorm2d(outs[1])
        self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernel_size[2], stride=stride[2], padding=padding[2])
        self.bn3 = nn.BatchNorm2d(outs[2])
        # projection shortcut: 1x1 conv matches channels (and stride) for the addition
        self.extra = nn.Sequential(
            nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride[3], padding=0),
            nn.BatchNorm2d(outs[2])
        )

    def forward(self, x):
        x_shortcut = self.extra(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        return F.relu(x_shortcut + out)


class ResNet50(nn.Module):
    def __init__(self):
        super(ResNet50, self).__init__()
        # stem: 7x7/2 conv + 3x3/2 max-pool (224 -> 56 for a 224x224 input)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(
            ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(256, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
        )
        self.layer2 = nn.Sequential(
            ResNet50DownBlock(256, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )
        self.layer3 = nn.Sequential(
            ResNet50DownBlock(512, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )
        self.layer4 = nn.Sequential(
            ResNet50DownBlock(1024, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
            ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
        )
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(2048, 10)  # 10 output classes (CIFAR-10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)  # flatten to [batch, 2048]
        out = self.fc(out)
        return out


if __name__ == '__main__':
    x = torch.randn(2, 3, 224, 224)
    net = ResNet50()
    out = net(x)
    print('out.shape: ', out.shape)
    print(out)
```
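As a quick sanity check, the output shape can be compared against torchvision's reference ResNet-50. A sketch (it assumes torchvision is installed; note the parameter counts will not match exactly, since the implementation above uses projection-shortcut blocks in several places where the reference uses identity blocks):

```python
import torch
from torchvision import models  # assumes torchvision is installed
from resnet50 import ResNet50

net = ResNet50()
ref = models.resnet50(num_classes=10)  # torchvision's reference ResNet-50

x = torch.randn(2, 3, 224, 224)
print(net(x).shape, ref(x).shape)  # both: torch.Size([2, 10])

count = lambda m: sum(p.numel() for p in m.parameters())
print(f'this impl: {count(net) / 1e6:.1f}M params, torchvision: {count(ref) / 1e6:.1f}M')
```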

main.py

```python
import torch
from torch import nn, optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

from resnet50 import ResNet50


# Train and evaluate on the CIFAR-10 dataset
def main():
    batchsz = 128

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = next(iter(cifar_train))  # peek at one batch
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet50().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)              # [b, 10]
            loss = criteon(logits, label)  # scalar tensor

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        # evaluate on the test set after each epoch
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()
```
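The training loop above writes nothing to disk. Here is a minimal checkpointing sketch (not part of the original post; the filename resnet50_cifar10.pth is hypothetical) of the save/resume round-trip that could run once per epoch inside the loop:

```python
import torch
from resnet50 import ResNet50

# Stand-ins for the live objects in main(); inside the training loop these
# would be the real `model`, `optimizer`, and `epoch`.
model = ResNet50()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# save a checkpoint
torch.save({'epoch': 0,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict()},
           'resnet50_cifar10.pth')  # hypothetical filename

# resume later
ckpt = torch.load('resnet50_cifar10.pth')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
print('resumed from epoch', ckpt['epoch'])
```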

Results

Training ran for 100 epochs (numbered 0-99); the final test accuracy was 0.7841.
