当前位置:   article > 正文

SE-NET se注意力机制应用于ResNet (附代码)_图像分类resnet加注意力机制代码

图像分类resnet加注意力机制代码

论文地址:https://arxiv.org/abs/1709.01507

代码地址:https://github.com/madao33/computer-vision-learning

1.是什么?

SE-NET网络是一种基于卷积神经网络的模型,它引入了SE(Squeeze-and-Excitation)块来增强通道之间的相互关系。SE块通过学习每个通道的重要性权重,使得有用的特征被放大,没有用的特征被抑制。SE块的实现需要满足两个标准:灵活性和学习非互斥关系。SE-NET网络在图像分类、目标检测和语义分割等任务中都取得了很好的效果。在语义分割任务中,SE-NET网络可以与UNet和DenseNet等基准网络结合使用,提高分割精度。

2.为什么?

SE结构设计的原因

对于图像,其输出都是由所有通道的求和产生,所以通道依赖关系隐含地嵌入到特征中,同时与滤波器捕获的局部空间相关性纠缠在一起。由卷积建模的通道关系本质上是隐式的和局部的(除了在最顶层的那些)。

作者希望通过明确地建模通道相互依赖来增强卷积特征的学习,从而使网络能够增加其对信息特征的敏感性,这些信息特征可被随后的转换所利用。

因此,作者希望为它提供获取全局信息的途径,并在它们被输入到下一个转换之前,通过挤压和激励两个步骤重新校准过滤器响应。

两个层的作用

1 挤压层:嵌入全局信息

为了解决通道依赖关系的问题,首先考虑信号到每个通道的输出特性。每个学习过的过滤器都使用一个局部接受域操作,因此转换输出 U 的每个单元无法利用该区域以外的上下文信息。

为了缓解这个问题,作者建议将全局空间信息压缩到一个通道描述符中。变换 U 的输出可以被解释为局部描述符的集合,这些描述符的统计信息表达了整个图像。

2 激励层:自适应重校

为了利用在挤压操作中聚合的信息,我们在它之后执行第二个操作,目的是完全捕获通道方面的依赖项。

为了实现这一目标,该操作必须符合两个标准:首先,它必须是灵活的(必须能够学习通道之间的非线性交互),其次,它必须学习一种非互斥关系,因为我们希望确保允许强调多个通道(而不是强制执行独热激活)。

为了满足这些标准,选择使用一个简单的带有sigmoid激活的门控机制。

为了限制模型的复杂性,通过在非线性周围形成两个全连接(FC)层的瓶颈来参数化门控机制,即降维层,降维率为 r,接一个ReLU,然后一个维度增加层返回到转换输出U的通道维度。块的最终输出是通过将U与激活sigmoid而获得的。

激励算子将特定输入的描述符映射为一组信道权值

3.怎么样?

3.1 SE块结构

SE块结构如下图所示,其先将给定信息 X 经 F 转换映射到 U,然后经过挤压操作,即对每个通道的整个空间维度 (H×W) 进行特征聚合映射,最后经过激励层,其采用一种简单的自选门机制形式,将嵌入作为输入,并产生每通道调制权值的集合。这些权重被应用到特征映射U上,生成SE块的输出,该输出可以直接输入到网络的后续层

可以通过简单地堆叠SE块的集合来构建SE网络(SENet)。此外,这些SE块也可以作为一个插入式替换原始块。

虽然构建块的模板是通用的,但它在整个网络的不同深度上所扮演的角色是不同的。在早期的层中,它以一种与类无关的方式激发信息特性,加强共享的低级表示。在后面的层中,SE块变得越来越专门化,并以一种高度特定于类的方式响应不同的输入。因此,特征重新校准的好处得以实现。

3.2 如何嵌入其他结构?

SE三种变体:

  1. SE- PRE块,SE位于残差块之前;
  2. SE- POST块,SE位于残差块之后;
  3. SE- Identity块,SE与残差块并行。

3.3 SE嵌入模型框架 

3.4 代码实现

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. def conv_block(in_channel, out_channel, relu_last=True, **kwargs):
  5. layers = [nn.Conv2d(in_channel, out_channel, bias=False, **kwargs),
  6. nn.BatchNorm2d(out_channel)]
  7. if relu_last:
  8. layers.append(nn.ReLU(inplace=True))
  9. return nn.Sequential(*layers)
  10. class ResidualSEBlock(nn.Module):
  11. expansion = 1
  12. def __init__(self, in_channel, out_channel, stride, r=16):
  13. super(ResidualSEBlock, self).__init__()
  14. self.residual = nn.Sequential(
  15. conv_block(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
  16. conv_block(out_channel, out_channel * self.expansion, kernel_size=3, padding=1)
  17. )
  18. self.shortcut = nn.Sequential()
  19. if stride != 1 or in_channel != out_channel * self.expansion:
  20. self.shortcut = conv_block(in_channel, out_channel * self.expansion, kernel_size=1, stride=stride,
  21. relu_last=False)
  22. self.squeeze = nn.AdaptiveAvgPool2d(1)
  23. self.excitation = nn.Sequential(
  24. nn.Linear(out_channel * self.expansion, out_channel * self.expansion // r),
  25. nn.ReLU(inplace=True),
  26. nn.Linear(out_channel * self.expansion // r, out_channel * self.expansion),
  27. nn.Sigmoid())
  28. def forward(self, x):
  29. r = self.residual(x)
  30. bs, c, _, _ = r.shape
  31. s = self.squeeze(r).view(bs, c)
  32. e = self.excitation(s).view(bs, c, 1, 1)
  33. return F.relu(self.shortcut(x) + r * e.expand_as(r))
  34. class SEResnet(nn.Module):
  35. def __init__(self, in_channel, n_classes, num_blocks, block):
  36. super(SEResnet, self).__init__()
  37. self.in_channels = 64
  38. self.feature = nn.Sequential(
  39. conv_block(in_channel, 64, kernel_size=3, padding=1),
  40. self._make_stage(64, 1, num_blocks[0], block),
  41. self._make_stage(128, 2, num_blocks[1], block),
  42. self._make_stage(256, 2, num_blocks[2], block),
  43. self._make_stage(512, 2, num_blocks[3], block)
  44. )
  45. self.classifier = nn.Sequential(
  46. nn.AdaptiveAvgPool2d((1, 1)),
  47. nn.Flatten(),
  48. nn.Linear(self.in_channels, n_classes),
  49. nn.LogSoftmax(dim=1)
  50. )
  51. def _make_stage(self, out_channel, stride, num_block, block):
  52. layers = []
  53. for i in range(num_block):
  54. stride = stride if i == 0 else 1
  55. layers.append(block(self.in_channels, out_channel, stride))
  56. self.in_channels = out_channel * block.expansion
  57. return nn.Sequential(*layers)
  58. def forward(self, x):
  59. return self.classifier(self.feature(x))
  60. def seresnet18():
  61. return SEResnet(3, 10, [2, 2, 2, 2], ResidualSEBlock)
  62. def seresnet34():
  63. return SEResnet(3, 10, [3, 4, 6, 3], ResidualSEBlock)

参考:

简单实现 SENet

学习Se-net和Sk-net 附网络简单代码(pytorch)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/290874
推荐阅读
相关标签
  

闽ICP备14008679号