当前位置:   article > 正文

YOLOv5、v8改进:CrissCrossAttention注意力机制

crisscrossattention注意力机制

目录

1.简介

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

2.2yolo.py中注册 CrissCrossAttention模块

2.3修改yaml文件。


1.简介

这是ICCV2019的用于语义分割的论文,可以说和CVPR2019的DANet遥相呼应。

和DANet一样,CCNet也是想建模像素之间的long range dependencies,来做更加丰富的contextual information,来补充特征图,以此来提升语义分割的性能。但是和DANet不一样,CCNet仅考虑空间分辨上的建模,不考虑建模通道之间的联系。作者提出的模块,criss-cross attention module,针对空间维度上的建模,对于空间位置的一个点u,仅考虑建模和u在同一行或者同一列的其他位置的像素之间的联系。相比DANet,能减少很多计算量,但是不足的是,对一个点的特征向量,尽管有同一行或者同一列的其他像素信息作为补充,对于语义分割任务,contextual information仍然是稀疏的(sparse),因为语义分割更在意一个像素和它周围的一些像素的关系。针对这个问题,作者提出了recurrent criss-cross attention module,来建模一个像素和全局所有像素的关系。方式是通过重复criss-cross attention module来实现的。这些module也是参数shared的。

同样是建模空间维度的pixel-wise contextual information,CCNet的计算量相较于self attention,可小太多了。一个CC module,要处理的是一个像素点和同一行、同一列一共(H+W-1)这么多的像素,那么应用在所有像素上,计算量就是O(HW(H+W-1))。回顾DANet的空间注意力分支(position attention module),每一个像素就要和(HW)个像素建模之间的联系,应用在所有相素,计算量就是O(HW*(H*W))。
通过递归的方式用CC module,可以对一个像素捕捉到全局的contextual information,提到了语义分割任务的效果。
个人看法,简单且有效的,就是极其优秀的方法,CCNet就属于这一类方法。
 

在这里插入图片描述

1.首先一个原图送进backbone,这个backbone是修改过的,把最后两个stage的stride改为1,同时应用空洞卷积来增大感受野。得到的特征图是原图的1/8.

2.然后经过1*1的卷积降维。得到H

3.H经过一个criss-cross attention module 得到H ′ 这个时候,H’中的每个位置都捕捉到了和u在同一行或者同一列的context information

4.H’经过一个相同结构、相同参数的cc module,得到了H’’。在H‘’中的每个位置,捕捉的是全局性的contextual information
5..最后经过一个分割层输出最后的预测结果。
在这里插入图片描述

 

之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入CrissCrossAttention注意力机制,该注意力机制为应用在语义分割中的模块,用于可以让网络更加关注待检测目标,提高检测效果

基本原理:

       语义分割的Criss-Cross网络(CCNet)的细节。我们首先介绍了CCNet的总体框架。然后,将介绍在水平和垂直方向捕获上下文信息的2D交叉注意力模块。为了获取密集的全局上下文信息,我们建议对交叉注意力模块采用循环操作。为了进一步改进RCCA,我们引入了判别损失函数来驱动RCCA学习类别一致性特征。最后,我们提出了同时利用时间和空间上下文信息的三维交叉注意模块。

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn import Softmax
  5. def INF(B,H,W):
  6. return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
  7. class CrissCrossAttention(nn.Module):
  8. """ Criss-Cross Attention Module"""
  9. def __init__(self, in_dim):
  10. super(CrissCrossAttention,self).__init__()
  11. self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
  12. self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
  13. self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
  14. self.softmax = Softmax(dim=3)
  15. self.INF = INF
  16. self.gamma = nn.Parameter(torch.zeros(1))
  17. def forward(self, x):
  18. m_batchsize, _, height, width = x.size()
  19. proj_query = self.query_conv(x)
  20. proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
  21. proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
  22. proj_key = self.key_conv(x)
  23. proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
  24. proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
  25. proj_value = self.value_conv(x)
  26. proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
  27. proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
  28. energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
  29. energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
  30. concate = self.softmax(torch.cat([energy_H, energy_W], 3))
  31. att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
  32. #print(concate)
  33. #print(att_H)
  34. att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
  35. out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
  36. out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
  37. #print(out_H.size(),out_W.size())
  38. return self.gamma*(out_H + out_W) + x

2.2yolo.py中注册 CrissCrossAttention模块

  1. elif m is CrissCrossAttention:
  2. c1, c2 = ch[f], args[0]
  3. if c2 != no:
  4. c2 = make_divisible(c2 * gw, 8)
  5. args = [c1, *args[1:]]

2.3修改yaml文件。

  1. # YOLOAir
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/398093
    推荐阅读
    相关标签