当前位置:   article > 正文

即插即用的涨点模块之注意力机制(CBAMAttention)详解及代码,可应用于检测、分割、分类等各种算法领域

cbam

目录

前言

一、CBAM结构

二、CBAM计算流程

三、CBAM参数

四、代码详解


前言

        CE模块通常只注意了通道特征,但在视觉任务中,空间任务通常更为重要,是不可忽略的,因此CBAM将通道注意力机制与空间注意力机制进行串联,充分关注特征信息。

        什么是空间特征?在深度学习中,空间特征是指描述输入数据在空间维度上的特征信息。对于图像数据而言,空间特征可以涵盖多种信息,包括边缘、角点、纹理、颜色等。这些特征信息可以帮助模型理解图像中不同区域的内容和结构,从而实现诸如目标检测、图像分割、图像分类等任务。在深度学习模型中,通常通过卷积神经网络(CNN)等结构来提取和学习空间特征,这些特征对于模型的表现和性能具有重要的影响。

        什么是空间注意力机制?空间注意力机制是一种注意力机制,用于在深度学习模型中对输入数据的不同空间位置进行加权,以便模型能够更加关注重要的空间位置,从而提高模型的性能和泛化能力。空间注意力机制通常应用在图像处理或自然语言处理等任务中,能够有效地捕捉输入数据在空间维度上的相关性。

        在空间注意力机制中,模型会学习到针对输入数据中不同空间位置的权重,以确定哪些位置对于任务是最重要的。这些权重可以根据输入数据的内容和上下文来自适应地调整,从而实现对不同空间位置的加权组合。通过引入空间注意力机制,模型可以更好地捕捉数据的局部特征和全局结构,从而提高模型的性能和泛化能力。

通道注意力机制

空间注意力机制

关注对象

关注于不同特征通道的重要性

关注于输入数据中不同位置的重要性

操作对象

输入数据的通道维度

输入数据的空间维度

应用范围

处理具有多个特征通道的数据

处理具有空间结构的数据


一、CBAM结构

        CBAM 是由Channel Attention ModuelSpatial Attention Module构成,结构如图1所示。Channel Attention Moduel,结构如图2所示。对输入的特征图分别同时进行最大池化和平均池化,通过对输入形状为(B,C,H,W)的特征图进行最大池化或平均池化操作,将每个通道(C)在空间维度上的信息进行压缩,最终得到形状为(B,C,1,1)的输出,在这个过程中,对于每个通道而言,它的空间信息被最大池化或者平均池化操作压缩为一个单独的值,从而实现了对全局空间信息的压缩和提取。这一步旨在将特征图上的信息集中在通道上,从而更好的在通道上捕捉到输入的特征图的特征信息,利用这两个特征可以大大提高网络的表示能力。共享网络由两个卷积和一个Relu激活函数构成,先降维再升维,这一步旨在减少参数开销,其中MLP中的权重是共享的,所用的输入都用相同的W0和W1权重矩阵进行计算处理,将共享网络应用于每个特征描述子后,使用元素求和(+)来合并输出特征向量,再将输出的特征向量通过sigmoid函数生成权重向量,确保它们的总和为1。Spatial Attention Module,结构如图3所示。对输入的特征图沿通道轴应用平均池化和最大池化,通过平均池化和最大池化操作,可以将输入张量的通道维度(C)压缩为1,从而将全局通道信息整合为一个单一的通道特征图,形状为(B,1,H,W)。在这个过程中,对于每个样本(B),模型会对该样本在通道上的特征进行平均池化,从而实现对全局通道信息的压缩合并。这种操作有助于减少参数数量、减小计算复杂度,同时保留重要的通道特征信息。将获得的两个矩阵在通道上拼接起来(torch.cat),并通过一个卷积层,将通道数再次变成1,使获得的特征信息全部分布在一个通道上,再将通过卷积层的输出通过sigmoid函数生成权重向量。CBAM则是将在Channel Attention Module得到的通道注意力权重乘以输入的原始特征图。这一步用于调整每个通道的特征值,强调重要通道的信息,抑制不重要通道的信息。再将之前在Spatial Attention Module得到的空间注意力权重乘以通过通道注意力机制得到的特征图,最终即得到最终输出结果。(通道和空间注意力机制可以并行或者顺序放置,发现顺序排列比平行排列产生更好结果,我们实验结果表明,通道优先顺序略优于空间优先顺序)

图1 CBAM结构

图2 通道注意力机制

        

图3 空间注意力机制

精读:CBAM(Convolutional Block Attention Module)是一个集成在卷积神经网络中的注意力模块,目的是增强模型的特征表达能力,通过强调重要的特征并抑制不重要的特征。CBAM 通过两个主要部分工作:Channel Attention Module 和 Spatial Attention Module。下面详细解释这两部分的工作原理及其互动方式。

Channel Attention Module (CAM)的核心目的是强调那些对当前任务更重要的特征通道。它通过以下步骤实现:

1.特征压缩:对输入的特征图X,形状为(B,C,H,W),进行最大池化和平均池化。这两种池化操作都在空间维度H×W 上进行,输出的结果是两个形状为(B,C,1,1)的特征图,即每个通道压缩成一个单独的值,分别代表了该通道的最大值和平均值。通过以下步骤实现:

2.维度转换:通过一个小型神经网络(通常是两层MLP),首先将通道数降维以减少参数量,然后再升维恢复到原始通道数。这个小网络包括两个全连接层和一个ReLU激活函数。

3.特征融合与激活:将最大池化和平均池化得到的两个特征图通过共享的MLP处理后,结果相加并通过sigmoid函数,得到每个通道的权重系数。

Spatial Attention Module (SAM) 的目的是在空间上强调更为关键的区域。它的步骤包括:

1.通道压缩:将处理后的特征图X 进行最大池化和平均池化,但这次是沿着通道轴 C,从而压缩所有通道信息到一个单通道图像中。操作结果是两个形状为(B,1,H,W)的特征图。

2. 特征拼接与卷积:将上述两个特征图在通道维度上拼接,然后通过一个卷积层将通道数变为1,最终通过sigmoid函数得到每个空间位置的权重系数。

整合与顺序

1.特征图权重调整:首先,通过Channel Attention Module得到的通道权重乘以原始的特征图X,调整每个通道的重要性。然后,将这个调整后的特征图输入到Spatial Attention Module,进一步调整每个位置的重要性。

2. 顺序优化:实验显示,首先应用Channel Attention(通道注意力)后再应用Spatial Attention(空间注意力)通常效果更好。这是因为,一旦我们确定了最重要的特征通道,再去调整这些通道中各个位置的重要性,能够更精确地强化有用的信息,抑制不必要的信息。

二、CBAM计算流程

 如图 1所示,给定一个输入为CBAM通过 Channel Attention Moduel获得,在Channel Attention Moduel中,先通过全局平均池化和全局最大池化分别获得,其中为Sigmoid函数,MLP结构为Conv-ReLU-Conv,为MLP的权重。

                            

为CBAM通过Spatial Attention Module获得,在Spatial Attention Module中,先通过全局平均池化和全局最大池化分别获得,其中为Sigmoid函数,表示滤波器为7*7的卷积运算。

将F通过Channel Attention Moduel得到的通道注意力权重乘以输入的原始特征图F,以获得,再将通过Spatial Attention Module得到的空间注意力权重乘以通过通道注意力机制得到的特征图,其中为元素乘法,为最终输出

三、CBAM参数

利用thop库的profile函数计算FLOPs和Param。Input:(512,7,7)。

Module

FLOPs

Param

CBAM

95938.0

32866.0

四、代码详解

 

  1. import torch
  2. from torch import nn
  3. from torch.nn import init
  4. class ChannelAttention(nn.Module):
  5. def __init__(self, in_planes, ratio=16):
  6. super(ChannelAttention, self).__init__()
  7. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  8. self.max_pool = nn.AdaptiveMaxPool2d(1)
  9. self.mlp=nn.Sequential(
  10. nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
  11. nn.ReLU(),
  12. nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
  13. )
  14. self.sigmoid = nn.Sigmoid()
  15. def forward(self, x):
  16. avg_out = self.mlp(self.avg_pool(x)) # 通过平均池化压缩全局空间信息: (B,C,H,W)--> (B,C,1,1) ,然后通过MLP降维升维:(B,C,1,1)
  17. max_out = self.mlp(self.max_pool(x)) # 通过最大池化压缩全局空间信息: (B,C,H,W)--> (B,C,1,1) ,然后通过MLP降维升维:(B,C,1,1)
  18. out = avg_out + max_out
  19. return self.sigmoid(out)
  20. class SpatialAttention(nn.Module):
  21. def __init__(self, kernel_size=7):
  22. super(SpatialAttention, self).__init__()
  23. assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
  24. padding = 3 if kernel_size == 7 else 1
  25. self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
  26. self.sigmoid = nn.Sigmoid()
  27. def forward(self, x):
  28. avg_out = torch.mean(x, dim=1, keepdim=True) # 通过平均池化压缩全局通道信息:(B,C,H,W)-->(B,1,H,W)
  29. max_out, _ = torch.max(x, dim=1, keepdim=True) # 通过最大池化压缩全局通道信息:(B,C,H,W)-->(B,1,H,W)
  30. x = torch.cat([avg_out, max_out], dim=1) # 在通道上拼接两个矩阵:(B,2,H,W)
  31. x = self.conv1(x) # 通过卷积层得到注意力权重:(B,2,H,W)-->(B,1,H,W)
  32. return self.sigmoid(x)
  33. class CBAM(nn.Module):
  34. def __init__(self, in_planes, ratio=16, kernel_size=7):
  35. super(CBAM, self).__init__()
  36. self.ca = ChannelAttention(in_planes, ratio)
  37. self.sa = SpatialAttention(kernel_size)
  38. def init_weights(self):
  39. for m in self.modules():
  40. if isinstance(m, nn.Conv2d):
  41. init.kaiming_normal_(m.weight, mode='fan_out')
  42. if m.bias is not None:
  43. init.constant_(m.bias, 0)
  44. elif isinstance(m, nn.BatchNorm2d):
  45. init.constant_(m.weight, 1)
  46. init.constant_(m.bias, 0)
  47. elif isinstance(m, nn.Linear):
  48. init.normal_(m.weight, std=0.001)
  49. if m.bias is not None:
  50. init.constant_(m.bias, 0)
  51. def forward(self, x):
  52. out = x * self.ca(x) # 通过通道注意力机制得到的特征图,x:(B,C,H,W),ca(x):(B,C,1,1),out:(B,C,H,W)
  53. result = out * self.sa(out) # 通过空间注意力机制得到的特征图,out:(B,C,H,W),sa(out):(B,1,H,W),result:(B,C,H,W)
  54. return result
  55. if __name__ == '__main__':
  56. from torchsummary import summary
  57. from thop import profile
  58. model = CBAM(in_planes=512)
  59. # summary(model, (512, 7, 7), device='cpu', batch_size=1)
  60. flops, params = profile(model, inputs=(torch.randn(1, 512, 7, 7),))
  61. print(f"FLOPs: {flops}, Params: {params}")

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/880475
推荐阅读
相关标签
  

闽ICP备14008679号