当前位置:   article > 正文

目标检测算法——YOLOv5/YOLOv7之结合CA注意力机制_目标检测模型用哪个注意力机制最后

目标检测模型用哪个注意力机制最后

深度学习Tricks,第一时间送达

论文题目:《Coordinate Attention for Efficient Mobile NetWork Design》
论文地址:  https://arxiv.org/pdf/2103.02907.pdf

本文中,作者通过将位置信息嵌入到通道注意力中提出了一种新颖的移动网络注意力机制,将其称为“Coordinate Attention”。与通过2维全局池化将特征张量转换为单个特征向量的通道注意力不同,Coordinate注意力将通道注意力分解为两个1维特征编码过程,分别沿2个空间方向聚合特征。这样,可以沿一个空间方向捕获远程依赖关系,同时可以沿另一空间方向保留精确的位置信息。然后将生成的特征图分别编码为一对方向感知和位置敏感的attention map,可以将其互补地应用于输入特征图,以增强关注对象的表示。

小海带将CA注意力模块嵌入到YOLOv5网络中,可进一步强化YOLOv5网络对方向和位置等信息的敏感度,并涨点明显。近期较忙,想要代码的小伙伴请私信~

1.网络结构图

2.CA模块代码

不同于通道注意力将输入通过2D全局池化转化为单个特征向量,CoordAttention将通道注意力分解为两个沿着不同方向聚合特征的1D特征编码过程。这样的好处是可以沿着一个空间方向捕获长程依赖,沿着另一个空间方向保留精确的位置信息。然后,将生成的特征图分别编码,形成一对方向感知和位置敏感的特征图,它们可以互补地应用到输入特征图来增强感兴趣的目标的表示。

CA模块通过精确的位置信息对通道关系和长程依赖进行编码,类似SE模块,也分为两个步骤:坐标信息嵌入(coordinate information embedding)和坐标注意力生成(coordinate attention generation),它的具体结构如下图。

  1. class h_sigmoid(nn.Module):
  2. def __init__(self, inplace=True):
  3. super(h_sigmoid, self).__init__()
  4. self.relu = nn.ReLU6(inplace=inplace)
  5. def forward(self, x):
  6. return self.relu(x + 3) / 6
  7. class h_swish(nn.Module):
  8. def __init__(self, inplace=True):
  9. super(h_swish, self).__init__()
  10. self.sigmoid = h_sigmoid(inplace=inplace)
  11. def forward(self, x):
  12. return x * self.sigmoid(x)
  13. class CoordAtt(nn.Module):
  14. def __init__(self, inp, oup, reduction=32):
  15. super(CoordAtt, self).__init__()
  16. self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
  17. self.pool_w = nn.AdaptiveAvgPool2d((1, None))
  18. mip = max(8, inp // reduction)
  19. self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
  20. self.bn1 = nn.BatchNorm2d(mip)
  21. self.act = h_swish()
  22. self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  23. self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
  24. def forward(self, x):
  25. identity = x
  26. x_h = self.pool_h(x)
  27. x_w = self.pool_w(x).permute(0, 1, 3, 2)
  28. y = torch.cat([x_h, x_w], dim=2)
  29. y = self.conv1(y)
  30. y = self.bn1(y)
  31. y = self.act(y)
  32. x_h, x_w = torch.split(y, [h, w], dim=2)
  33. x_w = x_w.permute(0, 1, 3, 2)
  34. a_h = self.conv_h(x_h).sigmoid()
  35. a_w = self.conv_w(x_w).sigmoid()
  36. out = identity * a_w * a_h
  37. return out

如何嵌入YOLOv5网络,各位小伙伴请参考上一篇博文~


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/873836

推荐阅读
相关标签