Generally speaking, attention mechanisms are divided into four basic categories:
Channel Attention
Spatial Attention
Temporal Attention
Branch Attention
CBAM is a lightweight convolutional attention module that combines channel and spatial attention.
Paper: CBAM: Convolutional Block Attention Module
Paper link: https://arxiv.org/pdf/1807.06521.pdf
As the figure above shows, CBAM consists of two sub-modules, CAM (Channel Attention Module) and SAM (Spatial Attention Module), which apply attention along the channel and spatial dimensions respectively. This not only saves parameters and computation, but also lets CBAM be integrated into existing network architectures as a plug-and-play module.
Beyond CBAM: GAM, a new attention mechanism that improves accuracy regardless of cost!
Paper: Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
Paper link: https://paperswithcode.com/paper/global-attention-mechanism-retain-information
Overall, GAM is quite similar to CBAM: both use a channel attention mechanism and a spatial attention mechanism. What differs is how the channel and spatial attention are computed.
The CBAM structure simply applies channel attention and spatial attention in sequence within a single block.
To implement CBAM in a ResNet, pass the output of the original block through channel attention and then spatial attention before it joins the residual connection.
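ResBlock_CBAM is imported in the tasks.py step below but its definition is never shown; here is a minimal, hypothetical sketch of such a residual block based on the description above. The class internals and layer choices are assumptions, with the two CBAM gates written inline so the snippet runs standalone (the real CBAM submodules are the ones added to modules.py below).

```python
import torch
import torch.nn as nn


class ResBlock_CBAM(nn.Module):
    # Hypothetical sketch: conv stack, then CBAM-style channel and spatial
    # gates, then the identity shortcut. Internals are assumptions.
    def __init__(self, c1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1, bias=False), nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, c1, 3, padding=1, bias=False), nn.BatchNorm2d(c1))
        self.fc = nn.Conv2d(c1, c1, 1, bias=True)             # channel gate (1x1 conv)
        self.cv1 = nn.Conv2d(2, 1, 7, padding=3, bias=False)  # spatial gate (7x7 conv)

    def forward(self, x):
        y = self.body(x)
        # channel attention: global-avg-pool -> 1x1 conv -> sigmoid
        y = y * torch.sigmoid(self.fc(nn.functional.adaptive_avg_pool2d(y, 1)))
        # spatial attention: channel-wise mean/max -> 7x7 conv -> sigmoid
        y = y * torch.sigmoid(self.cv1(torch.cat(
            [y.mean(1, keepdim=True), y.max(1, keepdim=True)[0]], 1)))
        return torch.relu(x + y)
```

Both gates multiply the features elementwise, so an input of shape (B, c1, H, W) maps to the same shape and the block drops in for a plain residual block.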
Add the following to modules.py (the counterpart of common.py in yolov5):

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    # Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    # Spatial-attention module
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    # Convolutional Block Attention Module
    def __init__(self, c1, kernel_size=7):  # ch_in, kernel size
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))
Also add the following to modules.py:

def channel_shuffle(x, groups=2):  # shuffle channels: reshape -> transpose -> flatten
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    out = out.view(B, C, H, W)
    return out
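To see what channel_shuffle does, trace a tiny example by hand; the snippet below inlines the function body for four channels and groups=2:

```python
import torch

# Channels [0, 1, 2, 3], groups=2: reshape -> transpose -> flatten
x = torch.arange(4.0).view(1, 4, 1, 1)
B, C, H, W = x.shape
out = x.view(B, 2, C // 2, H, W).permute(0, 2, 1, 3, 4).contiguous().view(B, C, H, W)
print(out.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]
```

The two groups [0, 1] and [2, 3] are interleaved into [0, 2, 1, 3], mixing information across groups after a grouped convolution.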

class GAM_Attention(nn.Module):
    # https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True, rate=4):
        super().__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, c1 // rate),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // rate, c1)
        )

        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3),
            nn.BatchNorm2d(c1 // rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        # x_channel_att = channel_shuffle(x_channel_att, 4)  # optional shuffle
        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # shuffle after spatial attention
        out = x * x_spatial_att
        # out = channel_shuffle(out, 4)  # optional shuffle
        return out
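To make the forward pass concrete, here is a self-contained walk-through of the same computation using the group=False branch and rate=4; the channel counts are arbitrary, chosen only for illustration:

```python
import torch
import torch.nn as nn

c1 = c2 = 16
rate = 4

# channel branch: an MLP over the channel dimension (mirrors GAM_Attention above)
channel_attention = nn.Sequential(
    nn.Linear(c1, c1 // rate),
    nn.ReLU(inplace=True),
    nn.Linear(c1 // rate, c1))
# spatial branch: two 7x7 convs squeeze and restore the channel count
spatial_attention = nn.Sequential(
    nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3),
    nn.BatchNorm2d(c1 // rate),
    nn.ReLU(inplace=True),
    nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3),
    nn.BatchNorm2d(c2))

x = torch.randn(2, c1, 8, 8)
b, c, h, w = x.shape

# flatten H*W so the MLP acts on the channel dimension, then restore the layout
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_channel_att = channel_attention(x_permute).view(b, h, w, c).permute(0, 3, 1, 2)
x = x * x_channel_att

# the 7x7 convs produce a per-pixel gate, squashed by sigmoid
out = x * spatial_attention(x).sigmoid()
print(out.shape)  # torch.Size([2, 16, 8, 8])
```

Padding 3 with kernel 7 keeps H and W unchanged, and with c1 == c2 the output shape matches the input, so GAM_Attention can be inserted between existing layers without other changes.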
2.3 Add CBAM, GAM_Attention and ResBlock_CBAM to tasks.py (the counterpart of yolo.py in yolov5)
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
                                    GhostBottleneck, GhostConv, Segment, CBAM, GAM_Attention, ResBlock_CBAM)
Then, in the def parse_model(d, ch, verbose=True): function, extend the module whitelist:
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
         BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, GAM_Attention, ResBlock_CBAM):
2.4 Modify the corresponding yaml files for CBAM and GAM
2.4.1 Adding CBAM to yolov8
# Ultralytics YOLO