
SENet Code Reproduction with Detailed Comments (PyTorch)


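Channel attention in convolutional networks often relies on the SENet module, which strengthens the model's ability to weight its channels and thereby improves accuracy. The principle and details of SENet were covered in the previous post: Classic Neural Network Papers Explained in Detail (7): SENet (Attention Mechanism) Study Notes (Translation + Close Reading + Code Reproduction).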

Next, let's reproduce the code.

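SENet is not an entirely new network architecture but a plug-and-play, high-performance add-on, so its implementation is fairly simple. In this post, SE blocks are added to ResNet to build SE_ResNet50.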


I. Structure of SENet

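The figure above shows an SE block. A network built from SE blocks is called an SENet. SE blocks can also be added to an existing network to form an SE-NameNet: adding the SE structure to AlexNet, ResNet, and so on gives SE-AlexNet, SE-ResNet, etc.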

Figure: SE blocks integrated into state-of-the-art architectures (Inception and ResNet)

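Principle: a global average pooling layer followed by two fully connected layers with their activations (ReLU and sigmoid) outputs as many weights as the input has feature channels, i.e. one weight coefficient per channel. In this way the block learns channel attention that decides which channels to emphasize and which to suppress.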
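The same structure can be written compactly. The formulas below follow the notation of the SENet paper as I understand it, so treat the symbols as assumptions: $\mathbf{u}_c$ is the $c$-th channel of the input feature map $U$ of size $C\times H\times W$, $\delta$ is ReLU, $\sigma$ is the sigmoid, and $\mathbf{W}_1\in\mathbb{R}^{\frac{C}{r}\times C}$, $\mathbf{W}_2\in\mathbb{R}^{C\times\frac{C}{r}}$ are the weights of the two FC layers (bias-free, as in the code later in this post):

$$
\begin{aligned}
z_c &= F_{sq}(\mathbf{u}_c) = \frac{1}{H\times W}\sum_{i=1}^{H}\sum_{j=1}^{W} u_c(i,j) \\
\mathbf{s} &= F_{ex}(\mathbf{z}, \mathbf{W}) = \sigma\big(\mathbf{W}_2\,\delta(\mathbf{W}_1\mathbf{z})\big) \\
\tilde{\mathbf{x}}_c &= F_{scale}(\mathbf{u}_c, s_c) = s_c\,\mathbf{u}_c
\end{aligned}
$$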

Detailed steps of the SE block

1. Start from the C×H×W feature map produced by an Inception or ResNet module. The Squeeze operation applies global average pooling (GAP) to it, producing a 1×1×C feature vector.

2. Two FC layers then form a bottleneck that models the correlations between channels:

  (1) The first FC layer reduces the C channels to C/r, which cuts the number of parameters; a ReLU nonlinearity follows before the second FC layer.

  (2) The second FC layer restores the number of channels to C, producing the attention weights.

3. A sigmoid activation normalizes these weights to the range (0, 1), and a Scale operation then multiplies each channel of the feature map by its weight.
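The three steps above can be sketched with plain tensor operations. This is a minimal illustration, not the module used later in this post: `se_forward`, `w1` and `w2` are hypothetical names, and the random weights only stand in for learned FC parameters.

```python
import torch

def se_forward(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    """Squeeze -> excitation -> scale on a (B, C, H, W) tensor.

    w1: (C // r, C) weight of the first FC layer (hypothetical, no bias)
    w2: (C, C // r) weight of the second FC layer (hypothetical, no bias)
    """
    # 1. Squeeze: global average pooling over H and W -> (B, C)
    z = x.mean(dim=(2, 3))
    # 2. Excitation: FC (C -> C/r), ReLU, FC (C/r -> C), sigmoid -> (B, C)
    s = torch.sigmoid(torch.relu(z @ w1.t()) @ w2.t())
    # 3. Scale: multiply every channel by its weight (broadcast over H and W)
    return x * s[:, :, None, None]

torch.manual_seed(0)
C, r = 64, 16
x = torch.randn(2, C, 8, 8)
w1 = torch.randn(C // r, C) * 0.1
w2 = torch.randn(C, C // r) * 0.1
y = se_forward(x, w1, w2)
print(y.shape)  # torch.Size([2, 64, 8, 8])
```

Because every weight lies in (0, 1), the output never grows in magnitude; the block can only keep or damp each channel.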


  II. The SE Block in Detail

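 Squeeze: the F_sq operation performs channel-wise global average pooling, compressing a W×H×C feature map that carries global information directly into a 1×1×C feature vector. Each two-dimensional channel becomes a single number with a global receptive field, so one "pixel" now represents one channel. Spatial distribution information is discarded so that the block can focus on the correlations between channels.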
In practice: global average pooling of a 50×512×7×7 feature map (batch size 50) yields a 50×512×1×1 feature map with a global receptive field.
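A quick sanity check of the squeeze step on random data of the shape used in the text: `nn.AdaptiveAvgPool2d((1, 1))` should give the same result as averaging over the last two dimensions by hand.

```python
import torch
import torch.nn as nn

# Feature map with the shape used in the text: batch 50, 512 channels, 7x7
x = torch.randn(50, 512, 7, 7)

gap = nn.AdaptiveAvgPool2d((1, 1))
pooled = gap(x)                              # (50, 512, 1, 1)
manual = x.mean(dim=(2, 3), keepdim=True)    # the same average, written by hand

print(pooled.shape)                               # torch.Size([50, 512, 1, 1])
print(torch.allclose(pooled, manual, atol=1e-6))  # True
```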


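Excitation: based on the correlations between feature channels, a weight is generated for each channel to represent its importance. In the figure, the C channels that were originally all white become a feature vector with colors of varying depth, i.e. different degrees of importance.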

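In practice: the 50×512×1×1 output passes through two fully connected layers followed by a gating function similar to the gates in recurrent neural networks (the paper uses a sigmoid). Its learned parameters explicitly model the correlations between feature channels and produce one weight per channel. The shape goes from 50×512×1×1 to 50×(512/16)×1×1 = 50×32×1×1 and is then restored to 50×512×1×1. (In the code below the tensor is flattened first, so the FC layers actually map 50×512 → 50×32 → 50×512 before it is reshaped back to 50×512×1×1.)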
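The excitation shapes for this example can be checked directly. The sketch below rebuilds an FC stack laid out like the `self.fc` of the `SE_Block` shown later (C = 512, r = 16, bias-free layers) and feeds it a random 50×512 descriptor standing in for the squeezed features. The parameter count 2·C·C/r = 32,768 is what the bottleneck costs per block at this width.

```python
import torch
import torch.nn as nn

C, r = 512, 16
# Same layout as SE_Block.fc later in this post (bias-free FC layers)
fc = nn.Sequential(
    nn.Linear(C, C // r, bias=False),   # 512 -> 32
    nn.ReLU(),
    nn.Linear(C // r, C, bias=False),   # 32 -> 512
    nn.Sigmoid(),
)

z = torch.randn(50, C)                  # squeezed descriptor, one row per image
hidden = fc[1](fc[0](z))                # bottleneck output (after ReLU)
s = fc(z)                               # per-channel weights

print(hidden.shape)                              # torch.Size([50, 32])
print(s.shape)                                   # torch.Size([50, 512])
print(sum(p.numel() for p in fc.parameters()))   # 32768
```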


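Reweight: the weights output by Excitation are treated as the importance of each feature channel. Every value at all H×W positions of a channel of U is multiplied by that channel's weight, which recalibrates the original features.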

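In practice: expand_as expands the 50×512×1×1 weights to 50×512×7×7, completing the channel-wise recalibration of the original features; the result is the input to the next layer.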
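For the reweight step, `expand_as` and PyTorch's ordinary broadcasting give identical results; `expand_as` only returns a view with the larger shape and does not copy data. A small check on random tensors (the weights are random numbers in [0, 1), standing in for sigmoid outputs):

```python
import torch

x = torch.randn(50, 512, 7, 7)     # original feature map U
w = torch.rand(50, 512, 1, 1)      # per-channel weights from the excitation step

scaled_expand = x * w.expand_as(x)  # explicit expansion, as in SE_Block
scaled_broadcast = x * w            # broadcasting performs the same expansion implicitly

print(w.expand_as(x).shape)                          # torch.Size([50, 512, 7, 7])
print(torch.equal(scaled_expand, scaled_broadcast))  # True
```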


III. PyTorch Implementation

(1) Building the SE block

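Global average pooling + FC + ReLU + FC + Sigmoid (on the pooled 1×1 map, the two fully connected layers are equivalent to 1×1 convolutions)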

```python
import torch.nn as nn

'''------------- 1. SE module -----------------------------'''
# Global average pooling + FC + ReLU + FC + Sigmoid
class SE_Block(nn.Module):
    def __init__(self, inchannel, ratio=16):
        super(SE_Block, self).__init__()
        # Global average pooling (F_sq, the squeeze operation)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Two fully connected layers (F_ex, the excitation operation)
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),  # c -> c/r
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),  # c/r -> c
            nn.Sigmoid()
        )

    def forward(self, x):
        # Batch size and number of channels of the input
        b, c, h, w = x.size()
        # F_sq: pool, then flatten to a (b, c) matrix
        y = self.gap(x).view(b, c)
        # F_ex: the FC layers give per-channel weights, reshaped to (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # F_scale: multiply the original feature map x by the weights
        return x * y.expand_as(x)
```
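The description above mentions 1×1 convolutions, while the code uses `nn.Linear`. On a pooled 1×1 map the two are equivalent. As a sketch (the class name `SE_Block_Conv` is hypothetical and not part of this post's code), here is a convolutional variant that needs no `view` calls, compared against the same computation written with matrix products:

```python
import torch
import torch.nn as nn

class SE_Block_Conv(nn.Module):
    """SE block written with 1x1 convolutions instead of nn.Linear (hypothetical variant)."""
    def __init__(self, inchannel, ratio=16):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Conv2d(inchannel, inchannel // ratio, kernel_size=1, bias=False),  # c -> c/r
            nn.ReLU(),
            nn.Conv2d(inchannel // ratio, inchannel, kernel_size=1, bias=False),  # c/r -> c
            nn.Sigmoid(),
        )

    def forward(self, x):
        # The 1x1 convolutions work directly on the (b, c, 1, 1) pooled map
        return x * self.fc(self.gap(x))

# Reference: the same weights used as linear layers
torch.manual_seed(0)
se = SE_Block_Conv(64)
x = torch.randn(2, 64, 5, 5)
w1 = se.fc[0].weight.view(4, 64)   # conv weight (4, 64, 1, 1) -> linear weight (4, 64)
w2 = se.fc[2].weight.view(64, 4)   # conv weight (64, 4, 1, 1) -> linear weight (64, 4)
z = x.mean(dim=(2, 3))
s = torch.sigmoid(torch.relu(z @ w1.t()) @ w2.t())
ref = x * s.view(2, 64, 1, 1)
print(torch.allclose(se(x), ref, atol=1e-5))  # True
```

The parameter count is identical (the conv kernels are just reshaped linear weights), so the choice between the two is mostly a matter of taste.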

(2) Embedding the SE block into the residual blocks

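The SE block can be flexibly added to complete models such as ResNet. It is usually placed on the residual branch, before the residual addition. (Because the sigmoid activation can cause vanishing gradients, it is best kept off the main identity path: even if one residual block suffers from the problem, the network as a whole is not affected.)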

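 Here we embed the SE module into ResNet's BasicBlock and Bottleneck to obtain SE versions of both (the classes below keep the names BasicBlock and Bottleneck). For a detailed explanation of these blocks, see my earlier post ResNet Code Reproduction with Detailed Comments (PyTorch).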

BasicBlock module

```python
import torch.nn as nn
import torch.nn.functional as F

'''------------- 2. BasicBlock module -----------------------------'''
# Residual block on the left of the ResNet block diagram (18-layer and 34-layer networks)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inchannel, outchannel, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        # SE_Block (defined above) goes after the last BN and before the shortcut addition
        self.SE = SE_Block(outchannel)
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion*outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion*outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # SE_Block already returns out * weights, so its output is used directly
        out = self.SE(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out
```
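`SE_Block.forward` already returns `x * y.expand_as(x)`, i.e. the rescaled feature map rather than the weights. A small sketch (`TinySE` is a hypothetical stand-in for `SE_Block` that also exposes its weights) shows why the block's output should be used directly: multiplying it by `out` once more yields out² · s, so the features would be squared.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for SE_Block: forward() returns the *rescaled* map x * s
class TinySE(nn.Module):
    def __init__(self, c, ratio=4):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(c, c // ratio, bias=False), nn.ReLU(),
                                nn.Linear(c // ratio, c, bias=False), nn.Sigmoid())

    def weights(self, x):
        # Per-channel weights s with shape (b, c, 1, 1)
        return self.fc(x.mean(dim=(2, 3)))[:, :, None, None]

    def forward(self, x):
        return x * self.weights(x)

torch.manual_seed(0)
se = TinySE(8)
out = torch.randn(2, 8, 4, 4)
s = se.weights(out)

correct = se(out)         # out * s: what the residual branch should carry
doubled = out * se(out)   # out * (out * s) = out**2 * s: features get squared

print(torch.allclose(correct, out * s))       # True
print(torch.allclose(doubled, out ** 2 * s))  # True
```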

Bottleneck module

```python
import torch.nn as nn
import torch.nn.functional as F

'''------------- 3. Bottleneck module -----------------------------'''
# Residual block on the right of the ResNet block diagram (50-, 101- and 152-layer networks)
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inchannel, outchannel, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        self.conv3 = nn.Conv2d(outchannel, self.expansion*outchannel,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*outchannel)
        # SE_Block (defined above) goes after the last BN and before the shortcut addition
        self.SE = SE_Block(self.expansion*outchannel)
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion*outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion*outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # SE_Block already returns out * weights, so its output is used directly
        out = self.SE(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out
```

(3) Building the SE_ResNet structure

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

'''------------- 4. Building the SE_ResNet structure -----------------------------'''
class SE_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SE_ResNet, self).__init__()
        self.in_planes = 64
        # conv1: a 3x3, stride-1 stem without max pooling (CIFAR-style)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)   # conv2_x
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)  # conv5_x
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the others use stride 1
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        out = self.linear(x)
        return out
```
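To see where the extra parameters come from, the arithmetic can be written out. Each `SE_Block` holds two bias-free linear layers with C·C/r weights each. The sketch below is plain Python and assumes the `SE_ResNet(Bottleneck, [3, 4, 6, 3])` configuration with r = 16; it sums the SE weights over the 16 bottleneck blocks. For C = 256 it gives 2 × 4,096 = 8,192, matching the two `Linear` rows of 4,096 in the torchsummary output later in this post, and the total of roughly 2.5 million is in line with the extra cost the SENet paper reports for SE-ResNet-50.

```python
# Extra parameters added by the SE blocks of SE_ResNet50 (Bottleneck, [3, 4, 6, 3], r = 16).
# Each SE_Block has two bias-free Linear layers: C*(C/r) + (C/r)*C = 2*C*(C/r) weights.
def se_params(channels: int, ratio: int = 16) -> int:
    return 2 * channels * (channels // ratio)

blocks_per_stage = [3, 4, 6, 3]
planes_per_stage = [64, 128, 256, 512]
expansion = 4  # Bottleneck.expansion

total = 0
for n, planes in zip(blocks_per_stage, planes_per_stage):
    c = planes * expansion
    print(f"C={c}: hidden={c // 16}, params per SE={se_params(c):,}, blocks={n}")
    total += n * se_params(c)
print(f"total SE params: {total:,}")  # total SE params: 2,514,944
```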

(4) Creating and testing the model

Create the SE_ResNet50 model and print it

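```python
import torch

# SE_ResNet50: SE_ResNet with Bottleneck blocks in a [3, 4, 6, 3] layout
# (the layout shown in the printed model below)
def SE_ResNet50():
    return SE_ResNet(Bottleneck, [3, 4, 6, 3])

# test
if __name__ == '__main__':
    model = SE_ResNet50()
    print(model)
    x = torch.randn(1, 3, 224, 224)
    out = model(x)
    print(out.shape)
```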

The printed model is shown below:

```
SE_ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=256, out_features=16, bias=False)
          (1): ReLU()
          (2): Linear(in_features=16, out_features=256, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=256, out_features=16, bias=False)
          (1): ReLU()
          (2): Linear(in_features=16, out_features=256, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=256, out_features=16, bias=False)
          (1): ReLU()
          (2): Linear(in_features=16, out_features=256, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=32, bias=False)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=512, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=32, bias=False)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=512, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=32, bias=False)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=512, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=512, out_features=32, bias=False)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=512, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
  )
  (layer3): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=1024, out_features=64, bias=False)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=1024, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=2048, out_features=128, bias=False)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=2048, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=2048, out_features=128, bias=False)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=2048, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (SE): SE_Block(
        (gap): AdaptiveAvgPool2d(output_size=(1, 1))
        (fc): Sequential(
          (0): Linear(in_features=2048, out_features=128, bias=False)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=2048, bias=False)
          (3): Sigmoid()
        )
      )
      (shortcut): Sequential()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (linear): Linear(in_features=2048, out_features=10, bias=True)
)
torch.Size([1, 10])
```

 Use torchsummary to print detailed per-layer information for the network

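```python
from torchsummary import summary

if __name__ == '__main__':
    # torchsummary's summary() runs on "cuda" by default, so the model is moved to the GPU
    net = SE_ResNet50().cuda()
    summary(net, (3, 224, 224))
```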

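The output is shown below. Because the stem is a 3×3, stride-1 convolution with no max pooling, layer1 runs at the full 224×224 resolution, which is why the forward/backward pass size comes to roughly 3.9 GB: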

```
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,728
BatchNorm2d-2 [-1, 64, 224, 224] 128
Conv2d-3 [-1, 64, 224, 224] 4,096
BatchNorm2d-4 [-1, 64, 224, 224] 128
Conv2d-5 [-1, 64, 224, 224] 36,864
BatchNorm2d-6 [-1, 64, 224, 224] 128
Conv2d-7 [-1, 256, 224, 224] 16,384
BatchNorm2d-8 [-1, 256, 224, 224] 512
AdaptiveAvgPool2d-9 [-1, 256, 1, 1] 0
Linear-10 [-1, 16] 4,096
ReLU-11 [-1, 16] 0
Linear-12 [-1, 256] 4,096
Sigmoid-13 [-1, 256] 0
SE_Block-14 [-1, 256, 224, 224] 0
Conv2d-15 [-1, 256, 224, 224] 16,384
BatchNorm2d-16 [-1, 256, 224, 224] 512
Bottleneck-17 [-1, 256, 224, 224] 0
Conv2d-18 [-1, 64, 224, 224] 16,384
BatchNorm2d-19 [-1, 64, 224, 224] 128
Conv2d-20 [-1, 64, 224, 224] 36,864
BatchNorm2d-21 [-1, 64, 224, 224] 128
Conv2d-22 [-1, 256, 224, 224] 16,384
BatchNorm2d-23 [-1, 256, 224, 224] 512
AdaptiveAvgPool2d-24 [-1, 256, 1, 1] 0
Linear-25 [-1, 16] 4,096
ReLU-26 [-1, 16] 0
Linear-27 [-1, 256] 4,096
Sigmoid-28 [-1, 256] 0
SE_Block-29 [-1, 256, 224, 224] 0
Bottleneck-30 [-1, 256, 224, 224] 0
Conv2d-31 [-1, 64, 224, 224] 16,384
BatchNorm2d-32 [-1, 64, 224, 224] 128
Conv2d-33 [-1, 64, 224, 224] 36,864
BatchNorm2d-34 [-1, 64, 224, 224] 128
Conv2d-35 [-1, 256, 224, 224] 16,384
BatchNorm2d-36 [-1, 256, 224, 224] 512
AdaptiveAvgPool2d-37 [-1, 256, 1, 1] 0
Linear-38 [-1, 16] 4,096
ReLU-39 [-1, 16] 0
Linear-40 [-1, 256] 4,096
Sigmoid-41 [-1, 256] 0
SE_Block-42 [-1, 256, 224, 224] 0
Bottleneck-43 [-1, 256, 224, 224] 0
Conv2d-44 [-1, 128, 224, 224] 32,768
BatchNorm2d-45 [-1, 128, 224, 224] 256
Conv2d-46 [-1, 128, 112, 112] 147,456
BatchNorm2d-47 [-1, 128, 112, 112] 256
Conv2d-48 [-1, 512, 112, 112] 65,536
BatchNorm2d-49 [-1, 512, 112, 112] 1,024
AdaptiveAvgPool2d-50 [-1, 512, 1, 1] 0
Linear-51 [-1, 32] 16,384
ReLU-52 [-1, 32] 0
Linear-53 [-1, 512] 16,384
Sigmoid-54 [-1, 512] 0
SE_Block-55 [-1, 512, 112, 112] 0
Conv2d-56 [-1, 512, 112, 112] 131,072
BatchNorm2d-57 [-1, 512, 112, 112] 1,024
Bottleneck-58 [-1, 512, 112, 112] 0
Conv2d-59 [-1, 128, 112, 112] 65,536
BatchNorm2d-60 [-1, 128, 112, 112] 256
Conv2d-61 [-1, 128, 112, 112] 147,456
BatchNorm2d-62 [-1, 128, 112, 112] 256
Conv2d-63 [-1, 512, 112, 112] 65,536
BatchNorm2d-64 [-1, 512, 112, 112] 1,024
AdaptiveAvgPool2d-65 [-1, 512, 1, 1] 0
Linear-66 [-1, 32] 16,384
ReLU-67 [-1, 32] 0
Linear-68 [-1, 512] 16,384
Sigmoid-69 [-1, 512] 0
SE_Block-70 [-1, 512, 112, 112] 0
Bottleneck-71 [-1, 512, 112, 112] 0
Conv2d-72 [-1, 128, 112, 112] 65,536
BatchNorm2d-73 [-1, 128, 112, 112] 256
Conv2d-74 [-1, 128, 112, 112] 147,456
BatchNorm2d-75 [-1, 128, 112, 112] 256
Conv2d-76 [-1, 512, 112, 112] 65,536
BatchNorm2d-77 [-1, 512, 112, 112] 1,024
AdaptiveAvgPool2d-78 [-1, 512, 1, 1] 0
Linear-79 [-1, 32] 16,384
ReLU-80 [-1, 32] 0
Linear-81 [-1, 512] 16,384
Sigmoid-82 [-1, 512] 0
SE_Block-83 [-1, 512, 112, 112] 0
Bottleneck-84 [-1, 512, 112, 112] 0
Conv2d-85 [-1, 128, 112, 112] 65,536
BatchNorm2d-86 [-1, 128, 112, 112] 256
Conv2d-87 [-1, 128, 112, 112] 147,456
BatchNorm2d-88 [-1, 128, 112, 112] 256
Conv2d-89 [-1, 512, 112, 112] 65,536
BatchNorm2d-90 [-1, 512, 112, 112] 1,024
AdaptiveAvgPool2d-91 [-1, 512, 1, 1] 0
Linear-92 [-1, 32] 16,384
ReLU-93 [-1, 32] 0
Linear-94 [-1, 512] 16,384
Sigmoid-95 [-1, 512] 0
SE_Block-96 [-1, 512, 112, 112] 0
Bottleneck-97 [-1, 512, 112, 112] 0
Conv2d-98 [-1, 256, 112, 112] 131,072
BatchNorm2d-99 [-1, 256, 112, 112] 512
Conv2d-100 [-1, 256, 56, 56] 589,824
BatchNorm2d-101 [-1, 256, 56, 56] 512
Conv2d-102 [-1, 1024, 56, 56] 262,144
BatchNorm2d-103 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-104 [-1, 1024, 1, 1] 0
Linear-105 [-1, 64] 65,536
ReLU-106 [-1, 64] 0
Linear-107 [-1, 1024] 65,536
Sigmoid-108 [-1, 1024] 0
SE_Block-109 [-1, 1024, 56, 56] 0
Conv2d-110 [-1, 1024, 56, 56] 524,288
BatchNorm2d-111 [-1, 1024, 56, 56] 2,048
Bottleneck-112 [-1, 1024, 56, 56] 0
Conv2d-113 [-1, 256, 56, 56] 262,144
BatchNorm2d-114 [-1, 256, 56, 56] 512
Conv2d-115 [-1, 256, 56, 56] 589,824
BatchNorm2d-116 [-1, 256, 56, 56] 512
Conv2d-117 [-1, 1024, 56, 56] 262,144
BatchNorm2d-118 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-119 [-1, 1024, 1, 1] 0
Linear-120 [-1, 64] 65,536
ReLU-121 [-1, 64] 0
Linear-122 [-1, 1024] 65,536
Sigmoid-123 [-1, 1024] 0
SE_Block-124 [-1, 1024, 56, 56] 0
Bottleneck-125 [-1, 1024, 56, 56] 0
Conv2d-126 [-1, 256, 56, 56] 262,144
BatchNorm2d-127 [-1, 256, 56, 56] 512
Conv2d-128 [-1, 256, 56, 56] 589,824
BatchNorm2d-129 [-1, 256, 56, 56] 512
Conv2d-130 [-1, 1024, 56, 56] 262,144
BatchNorm2d-131 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-132 [-1, 1024, 1, 1] 0
Linear-133 [-1, 64] 65,536
ReLU-134 [-1, 64] 0
Linear-135 [-1, 1024] 65,536
Sigmoid-136 [-1, 1024] 0
SE_Block-137 [-1, 1024, 56, 56] 0
Bottleneck-138 [-1, 1024, 56, 56] 0
Conv2d-139 [-1, 256, 56, 56] 262,144
BatchNorm2d-140 [-1, 256, 56, 56] 512
Conv2d-141 [-1, 256, 56, 56] 589,824
BatchNorm2d-142 [-1, 256, 56, 56] 512
Conv2d-143 [-1, 1024, 56, 56] 262,144
BatchNorm2d-144 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-145 [-1, 1024, 1, 1] 0
Linear-146 [-1, 64] 65,536
ReLU-147 [-1, 64] 0
Linear-148 [-1, 1024] 65,536
Sigmoid-149 [-1, 1024] 0
SE_Block-150 [-1, 1024, 56, 56] 0
Bottleneck-151 [-1, 1024, 56, 56] 0
Conv2d-152 [-1, 256, 56, 56] 262,144
BatchNorm2d-153 [-1, 256, 56, 56] 512
Conv2d-154 [-1, 256, 56, 56] 589,824
BatchNorm2d-155 [-1, 256, 56, 56] 512
Conv2d-156 [-1, 1024, 56, 56] 262,144
BatchNorm2d-157 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-158 [-1, 1024, 1, 1] 0
Linear-159 [-1, 64] 65,536
ReLU-160 [-1, 64] 0
Linear-161 [-1, 1024] 65,536
Sigmoid-162 [-1, 1024] 0
SE_Block-163 [-1, 1024, 56, 56] 0
Bottleneck-164 [-1, 1024, 56, 56] 0
Conv2d-165 [-1, 256, 56, 56] 262,144
BatchNorm2d-166 [-1, 256, 56, 56] 512
Conv2d-167 [-1, 256, 56, 56] 589,824
BatchNorm2d-168 [-1, 256, 56, 56] 512
Conv2d-169 [-1, 1024, 56, 56] 262,144
BatchNorm2d-170 [-1, 1024, 56, 56] 2,048
AdaptiveAvgPool2d-171 [-1, 1024, 1, 1] 0
Linear-172 [-1, 64] 65,536
ReLU-173 [-1, 64] 0
Linear-174 [-1, 1024] 65,536
Sigmoid-175 [-1, 1024] 0
SE_Block-176 [-1, 1024, 56, 56] 0
Bottleneck-177 [-1, 1024, 56, 56] 0
Conv2d-178 [-1, 512, 56, 56] 524,288
BatchNorm2d-179 [-1, 512, 56, 56] 1,024
Conv2d-180 [-1, 512, 28, 28] 2,359,296
BatchNorm2d-181 [-1, 512, 28, 28] 1,024
Conv2d-182 [-1, 2048, 28, 28] 1,048,576
BatchNorm2d-183 [-1, 2048, 28, 28] 4,096
AdaptiveAvgPool2d-184 [-1, 2048, 1, 1] 0
Linear-185 [-1, 128] 262,144
ReLU-186 [-1, 128] 0
Linear-187 [-1, 2048] 262,144
Sigmoid-188 [-1, 2048] 0
SE_Block-189 [-1, 2048, 28, 28] 0
Conv2d-190 [-1, 2048, 28, 28] 2,097,152
BatchNorm2d-191 [-1, 2048, 28, 28] 4,096
Bottleneck-192 [-1, 2048, 28, 28] 0
Conv2d-193 [-1, 512, 28, 28] 1,048,576
BatchNorm2d-194 [-1, 512, 28, 28] 1,024
Conv2d-195 [-1, 512, 28, 28] 2,359,296
BatchNorm2d-196 [-1, 512, 28, 28] 1,024
Conv2d-197 [-1, 2048, 28, 28] 1,048,576
BatchNorm2d-198 [-1, 2048, 28, 28] 4,096
AdaptiveAvgPool2d-199 [-1, 2048, 1, 1] 0
Linear-200 [-1, 128] 262,144
ReLU-201 [-1, 128] 0
Linear-202 [-1, 2048] 262,144
Sigmoid-203 [-1, 2048] 0
SE_Block-204 [-1, 2048, 28, 28] 0
Bottleneck-205 [-1, 2048, 28, 28] 0
Conv2d-206 [-1, 512, 28, 28] 1,048,576
BatchNorm2d-207 [-1, 512, 28, 28] 1,024
Conv2d-208 [-1, 512, 28, 28] 2,359,296
BatchNorm2d-209 [-1, 512, 28, 28] 1,024
Conv2d-210 [-1, 2048, 28, 28] 1,048,576
BatchNorm2d-211 [-1, 2048, 28, 28] 4,096
AdaptiveAvgPool2d-212 [-1, 2048, 1, 1] 0
Linear-213 [-1, 128] 262,144
ReLU-214 [-1, 128] 0
Linear-215 [-1, 2048] 262,144
Sigmoid-216 [-1, 2048] 0
SE_Block-217 [-1, 2048, 28, 28] 0
Bottleneck-218 [-1, 2048, 28, 28] 0
AdaptiveAvgPool2d-219 [-1, 2048, 1, 1] 0
Linear-220 [-1, 10] 20,490
================================================================
Total params: 26,035,786
Trainable params: 26,035,786
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 3914.25
Params size (MB): 99.32
Estimated Total Size (MB): 4014.14
----------------------------------------------------------------
```

(5) Complete code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

'''-------------1. SE module-----------------------------'''
# Global average pooling + FC (c -> c/r) + ReLU + FC (c/r -> c) + Sigmoid
class SE_Block(nn.Module):
    def __init__(self, inchannel, ratio=16):
        super(SE_Block, self).__init__()
        # Global average pooling (F_sq, squeeze)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Two fully connected layers (F_ex, excitation)
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),  # c -> c/r
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),  # c/r -> c
            nn.Sigmoid()
        )

    def forward(self, x):
        # Read the batch size and number of channels
        b, c, h, w = x.size()
        # F_sq: pooling yields a (b, c) matrix
        y = self.gap(x).view(b, c)
        # F_ex: the FC layers output channel weights, reshaped to (b, c, 1, 1)
        y = self.fc(y).view(b, c, 1, 1)
        # F_scale: multiply the original feature map x by the weights
        return x * y.expand_as(x)

'''-------------2. BasicBlock module-----------------------------'''
# Basic residual block (left structure in the ResNet paper's figure), used by the 18- and 34-layer networks
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inchannel, outchannel, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        # SE_Block is placed after the last BN and before the shortcut addition
        self.SE = SE_Block(outchannel)
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion * outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion * outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # SE_Block.forward already returns out * weights, so its output is used directly
        # (multiplying it by out again would square the features)
        out = self.SE(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

'''-------------3. Bottleneck module-----------------------------'''
# Bottleneck residual block (right structure in the ResNet paper's figure), used by the 50-, 101- and 152-layer networks
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inchannel, outchannel, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        self.conv3 = nn.Conv2d(outchannel, self.expansion * outchannel,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * outchannel)
        # SE_Block is placed after the last BN and before the shortcut addition
        self.SE = SE_Block(self.expansion * outchannel)
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion * outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion * outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # SE_Block.forward already returns out * weights, so its output is used directly
        out = self.SE(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

'''-------------4. Build the SE_ResNet structure-----------------------------'''
class SE_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SE_ResNet, self).__init__()
        self.in_planes = 64
        # CIFAR-style stem: 3x3 stride-1 conv and no max pooling, so a 224x224 input
        # is still 28x28 in layer4 (as in the summary output above)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)              # conv1
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)   # conv2_x
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)  # conv5_x
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        out = self.linear(x)
        return out

def SE_ResNet18():
    return SE_ResNet(BasicBlock, [2, 2, 2, 2])

def SE_ResNet34():
    return SE_ResNet(BasicBlock, [3, 4, 6, 3])

def SE_ResNet50():
    return SE_ResNet(Bottleneck, [3, 4, 6, 3])

def SE_ResNet101():
    return SE_ResNet(Bottleneck, [3, 4, 23, 3])

def SE_ResNet152():
    return SE_ResNet(Bottleneck, [3, 8, 36, 3])

# Alternative quick test without torchsummary:
# if __name__ == '__main__':
#     model = SE_ResNet50()
#     print(model)
#     input = torch.randn(1, 3, 224, 224)
#     out = model(input)
#     print(out.shape)

if __name__ == '__main__':
    # torchsummary's summary() runs on device="cuda" by default, so the model is moved to the GPU
    net = SE_ResNet50().cuda()
    summary(net, (3, 224, 224))
```
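SE_Block.forward already returns x * y.expand_as(x). That is the recalibrated feature map, not the channel weights. The residual branch should therefore take the block's output as-is. Multiplying that result by the features once more (the out * self.SE(out) pattern) computes x² · y instead of x · y. The check below is a minimal sketch that rebuilds SE_Block with the same layers. The `channel_weights` method is a hypothetical helper split out of forward for illustration and is not part of the article's code.

```python
import torch
import torch.nn as nn


class SE_Block(nn.Module):
    """Same layers as the article's SE_Block: GAP -> FC(c->c/r) -> ReLU -> FC(c/r->c) -> Sigmoid."""

    def __init__(self, inchannel, ratio=16):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),
            nn.Sigmoid(),
        )

    def channel_weights(self, x):
        # Hypothetical helper: squeeze + excitation only, returns (b, c, 1, 1) weights
        b, c, _, _ = x.size()
        return self.fc(self.gap(x).view(b, c)).view(b, c, 1, 1)

    def forward(self, x):
        # Scale: reweight each channel of x (same result as the article's forward)
        return x * self.channel_weights(x)


torch.manual_seed(0)
se = SE_Block(64)
x = torch.randn(2, 64, 8, 8)

with torch.no_grad():
    w = se.channel_weights(x)  # sigmoid outputs, one weight per channel
    once = se(x)               # x * w: what the residual branch should use
    twice = x * se(x)          # x * (x * w) = x**2 * w: the double-scaling pattern

print(torch.allclose(once, x * w))          # True
print(torch.allclose(twice, x.pow(2) * w))  # True
print(torch.allclose(once, twice))          # False
```

Replacing the features with the SE output keeps the block a pure per-channel rescaling with weights in (0, 1), which is what the Fscale step describes.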

That's all for this post. Feel free to leave comments and discuss!

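Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site bears no related legal liability. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/花生_TL007/article/detail/290861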