
How to Understand the Residual Network (ResNet) Structure and Its PyTorch Implementation: Notes


Among deep learning architectures, I personally consider the residual network the most fundamental one. This post is not about the theory behind residual networks; just keep one point in mind: the residual idea runs through many later architectures, so once you understand the structure of a residual network, the structures of many more advanced networks become easy to follow.

Overall structure of the residual network (figure)

1. The residual block structure

The residual block used by ResNets shallower than 50 layers (the 50-layer network not included), i.e. BasicBlock, is implemented as follows:

import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        # downsample is the extra conv placed on the shortcut of the dashed-line
        # residual block; it stays None for the solid-line (identity) block
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out
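As a quick sanity check of the shapes (a minimal sketch of my own, not part of the original post): with stride=1 and no downsample module, BasicBlock leaves both the spatial size and the channel count unchanged, so the identity addition goes through directly.

import torch

block = BasicBlock(in_channel=64, out_channel=64)   # stride=1, downsample=None (solid-line block)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)   # torch.Size([1, 64, 56, 56])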

The residual block used by ResNets with 50 or more layers (the 50-layer network included), i.e. Bottleneck, is implemented as follows:

import torch.nn as nn

class Bottleneck(nn.Module):
    """
    Note: in the original paper, on the main branch of the dashed-line residual block,
    the first 1x1 conv has stride 2 and the 3x3 conv has stride 1.
    In the official PyTorch implementation the first 1x1 conv has stride 1 and the
    3x3 conv has stride 2, which improves top-1 accuracy by roughly 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out
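Likewise, a minimal shape check for Bottleneck (again my own sketch): because expansion = 4, passing out_channel=64 actually produces 64 * 4 = 256 output channels, so the identity shortcut only lines up when the input already has 256 channels.

import torch

block = Bottleneck(in_channel=256, out_channel=64)   # stride=1, downsample=None
x = torch.randn(1, 256, 56, 56)
print(block(x).shape)   # torch.Size([1, 256, 56, 56])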

Looking at these two residual blocks (or their code), a beginner may well wonder: why are there two kinds? They target networks of different depths. The first block (BasicBlock) is used in shallow residual networks such as resnet18 and resnet34, while the second (Bottleneck) is used in deep residual networks such as resnet50, resnet101 and resnet152.

Both blocks are implemented separately in the code precisely so that the depth of the network can be changed easily. Residual blocks are usually named Block in most codebases, so read the code side by side with the structure diagram.

Also note that in these blocks the 3x3 convolutions (when given stride 2) are generally used to reduce the spatial size of the feature map, while the 1x1 convolutions are used to decrease or increase the number of channels.
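A quick sketch of that point (independent of the classes above, shapes chosen only for illustration): a 3x3 stride-2 convolution halves the spatial size, while a 1x1 convolution only changes the channel count.

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

# 3x3 conv, stride 2, padding 1: halves the spatial size, 56x56 -> 28x28
conv3x3 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
print(conv3x3(x).shape)   # torch.Size([1, 64, 28, 28])

# 1x1 conv, stride 1: only changes the channel count, 64 -> 128
conv1x1 = nn.Conv2d(64, 128, kernel_size=1, stride=1, bias=False)
print(conv1x1(x).shape)   # torch.Size([1, 128, 56, 56])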

2. The difference between concat and add

These two words tend to confuse beginners, or are hard to picture at first, so they are worth a short note.

The concat operation generally requires the feature maps to have the same spatial size, so that they can be stacked along the channel dimension (see the code sketch after the add description below).

The add operation generally requires both the spatial size and the channel count to match. For example, two single-channel 2x2 feature maps can be added element-wise at corresponding positions.
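A minimal sketch of both operations (shapes chosen only for illustration):

import torch

a = torch.randn(1, 64, 28, 28)
b = torch.randn(1, 64, 28, 28)

# concat: same spatial size required; channel counts are stacked (64 + 64 = 128)
cat = torch.cat([a, b], dim=1)
print(cat.shape)     # torch.Size([1, 128, 28, 28])

# add: same spatial size AND same channel count required; values are summed element-wise
added = a + b
print(added.shape)   # torch.Size([1, 64, 28, 28])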

3. Why the shortcut branch needs downsampling

See the figure below. You may notice that the shortcut branch of the two residual blocks above does not explicitly carry the "1x1, 128" module shown in the figure; the original author simply assumed the reader already knows some deep learning and left it out. Analyzing the figure carefully: the input [56, 56, 64] first passes through a 3x3, 128, stride-2 convolution and becomes [28, 28, 128]; after a further 3x3, 128, stride-1 convolution it stays [28, 28, 128]. This no longer matches the input [56, 56, 64] in either spatial size or channel count, so on the shortcut branch the input [56, 56, 64] also goes through a 1x1, 128, stride-2 convolution, which likewise yields [28, 28, 128]; the two [28, 28, 128] tensors can then be added.

The corresponding code is as follows:

import torch.nn as nn
import torch

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        # here downsample is used as a flag: when it is set, a 1x1 stride-2 conv + BN
        # is placed on the shortcut so the identity matches the main branch's shape
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = None
        if downsample is not None:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                          kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_channel))

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            # the shortcut only needs conv + BN (no ReLU) to reshape the identity
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out

if __name__ == '__main__':
    a = torch.randn((1, 64, 56, 56))
    model = BasicBlock(in_channel=64, out_channel=128, stride=2, downsample=True)
    out = model(a)
    print(out.shape)
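Running this small test should print torch.Size([1, 128, 28, 28]): the main branch and the shortcut both end up at 28x28 with 128 channels, so the element-wise addition goes through.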

The complete ResNet code is as follows:

import torch.nn as nn
import torch


# the 3x3 + 3x3 residual block used by the shallow ResNets (resnet18 / resnet34)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        # downsample is the extra conv on the shortcut of the dashed-line block; None for the solid-line block
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


# the bottleneck residual block used by ResNet-50 and deeper
class Bottleneck(nn.Module):
    """
    Note: in the original paper, on the main branch of the dashed-line residual block,
    the first 1x1 conv has stride 2 and the 3x3 conv has stride 1.
    In the official PyTorch implementation the first 1x1 conv has stride 1 and the
    3x3 conv has stride 2, which improves top-1 accuracy by roughly 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,         # e.g. [3, 4, 6, 3] for resnet34
                 num_classes=1000,
                 include_top=True,   # set to False when reusing the backbone in a larger network
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


if __name__ == '__main__':
    net = resnet34()
    print(net)
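As a final check of the overall shapes (a minimal sketch of my own, appended to or importing the file above), a 224x224 RGB image comes out as a 1000-dimensional logit vector:

import torch

net = resnet34()             # or resnet50(), resnet101(), resnext50_32x4d(), ...
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = net(x)
print(logits.shape)          # torch.Size([1, 1000])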

That wraps up the explanation of the network structure! I hope you found it helpful; if anything is unclear, feel free to ask in the comments!
