当前位置:   article > 正文

即插即用的涨点模块之注意力机制(SEAttention)详解及代码,可应用于检测、分割、分类等各种算法领域

seattention

目录

前言

一、SENet结构

二、SENet计算流程

三、SENet参数

四、代码讲解 


前言

Squeeze-and-Excitation Networks(SENet)

来源:CVPR2018

官方代码:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks

        什么是通道特征?通道特征(Channel Features)是指卷积神经网络(CNN)中每个卷积核产生的输出。一个通道对应于网络中的一个卷积核,而每个通道的输出表示该卷积核在输入上的响应。通道特征捕捉了输入数据中不同方面的抽象信息。每个通道对应于某种特定的抽象特征,例如纹理、颜色、边缘等。通道特征在整个网络中负责提取和表示不同层次的信息。

        什么是通道注意力机制?通道注意力机制(Channel Attention Mechanism)是深度学习中一种用于增强通道特征捕捉能力的注意力机制。它主要应用于卷积神经网络(CNN)中,以提高模型对不同通道(channel)的特征的关注度,从而使网络更加有效地学习和利用输入数据的信息。在通道注意力机制中,通过学习每个通道的权重,模型可以在处理特定通道的特征时给予更多的注意力。这有助于网络在学习过程中更好地区分不同通道的重要性,从而提高模型对输入数据的表示能力。


一、SENet结构

SENet是一种通道注意力机制,结构如图1所示。SE注意力模块,由Squeeze操作、Excitation操作、Scale操作三部分组成。Squeeze操作:对输入的特征图进行全局平均池化,将每个通道的特征值降维为一个全局向量。这一步旨在捕捉每个通道的全局信息。Excitation操作:由两个全连接和一个ReLU激活函数、一个Softmax激活函数组成,先进行降维在升维,最后通过sigmoid函数生成权重向量,确保它们的总和为1。Scale操作:将上一步得到的通道注意力权重乘以输入的原始特征图。这一步用于调整每个通道的特征值,强调重要通道的信息,抑制不重要通道的信息。SE注意力模块与Inception、ResNet的结合,分别如图2、图3所示。

图1 SENet block

图2 原始 Inception 模块(左)和 SE-Inception 模块(右)的架构。

图3 原始 Residual 模块(左)和 SE-ResNet 模块(右)的架构

二、SENet计算流程

        如图1所示,给定一个输入X∈H'×W'×C' ,通过一个卷积变化Ftr 得到 UH×W×C 将特征U 经过Squeeze操作Fsq 在空间维度H×W 上聚合特征得到Z∈1×1×C 。接下来进行Excitation操作,得到s。其中W1 、W2 为全连接层后的权重,δ 为ReLU函数,σ 为Sigmoid函数。

s=Fex(z,W)=σ(g(z,W))=σ(W2δ(W1z))

        最后进行Scale操作,将特征U和特征s相乘得到x,通过相乘,注意力机制可以对每个通道进行更细粒度的权重调整,将更多的注意力集中在对任务更为关键的通道上。从而调整该特征值的重要性。这个过程使得网络更加关注那些在给定任务中对应通道上重要的特征。

x=Fscale (U,s)=Us

三、SENet参数

利用thop库的profile函数计算FLOPs和Param。Input:(512,7,7)。

ModuleFLOPsParam
SEAttention9113665536

四、代码讲解 

  1. import torch
  2. from torch import nn
  3. from torch.nn import init
  4. class SEAttention(nn.Module):
  5. def __init__(self, channel=512, reduction=16):
  6. super().__init__()
  7. # 在空间维度上,将H×W压缩为1×1
  8. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  9. # 包含两层全连接,先降维,后升维。最后接一个sigmoid函数
  10. self.fc = nn.Sequential(
  11. nn.Linear(channel, channel // reduction, bias=False),
  12. nn.ReLU(inplace=True),
  13. nn.Linear(channel // reduction, channel, bias=False),
  14. nn.Sigmoid()
  15. )
  16. def init_weights(self):
  17. for m in self.modules():
  18. if isinstance(m, nn.Conv2d):
  19. init.kaiming_normal_(m.weight, mode='fan_out')
  20. if m.bias is not None:
  21. init.constant_(m.bias, 0)
  22. elif isinstance(m, nn.BatchNorm2d):
  23. init.constant_(m.weight, 1)
  24. init.constant_(m.bias, 0)
  25. elif isinstance(m, nn.Linear):
  26. init.normal_(m.weight, std=0.001)
  27. if m.bias is not None:
  28. init.constant_(m.bias, 0)
  29. def forward(self, x):
  30. # (B,C,H,W)
  31. B, C, H, W = x.size()
  32. # Squeeze: (B,C,H,W)-->avg_pool-->(B,C,1,1)-->view-->(B,C)
  33. y = self.avg_pool(x).view(B, C)
  34. # Excitation: (B,C)-->fc-->(B,C)-->(B, C, 1, 1)
  35. y = self.fc(y).view(B, C, 1, 1)
  36. # scale: (B,C,H,W) * (B, C, 1, 1) == (B,C,H,W)
  37. out = x * y
  38. return out
  39. if __name__ == '__main__':
  40. from torchsummary import summary
  41. from thop import profile
  42. model = SEAttention(channel=512, reduction=8)
  43. summary(model, (512, 7, 7), device='cpu')
  44. flops, params = profile(model, inputs=(torch.randn(1, 512, 7, 7),))
  45. print(f"FLOPs: {flops}, Params: {params}")

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/603892
推荐阅读
相关标签
  

闽ICP备14008679号