The paper "CCNet: Criss-Cross Attention for Semantic Segmentation" proposes a criss-cross network (CCNet) that captures full-image contextual information both effectively and efficiently. Concretely, for each pixel, a novel criss-cross attention (CCA) module aggregates contextual information from all pixels on its criss-cross path, i.e., its row and column. By applying the module recurrently, each pixel can ultimately capture dependencies over the whole image (see the recurrent wrapper sketched after the code below). In addition, the paper introduces a category consistent loss that forces the criss-cross attention module to produce more discriminative features.
The structure of the 2D CCA module is shown in the figure below:
The structure of the 3D CCA module is shown in the figure below:
Application examples of CCA are shown below:
The CCA module is implemented as follows:
import torch
import torch.nn as nn


def INF(B, H, W):
    # Returns a (B*W, H, H) tensor with -inf on each H x H diagonal.
    # Added to the vertical (column) energy so a pixel's attention to itself
    # is not counted twice: it already appears once in the horizontal path.
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        # 1x1 convolutions produce query/key maps (reduced to C/8 channels) and the value map
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=3)
        self.INF = INF
        # Learnable residual weight, initialized to zero so the module starts as identity
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        # Treat each column / each row as a batch of sequences:
        # proj_query_H: (B*W, H, C'), proj_query_W: (B*H, W, C')
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        # proj_key_H: (B*W, C', H), proj_key_W: (B*H, C', W)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        # Column-wise affinity with -inf on the diagonal to mask the duplicate self term;
        # energy_H: (B, H, W, H), energy_W: (B, H, W, W)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        # Joint softmax over all H+W positions on each pixel's criss-cross path
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        # Aggregate values along columns (out_H) and rows (out_W); both end up as (B, C, H, W)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        # Residual connection scaled by the learnable gamma
        return self.gamma * (out_H + out_W) + x


if __name__ == '__main__':
    x = torch.randn(6, 8, 256, 256)
    cca = CrissCrossAttention(in_dim=8)
    output = cca(x)
    print(output.shape)  # torch.Size([6, 8, 256, 256])
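
As noted above, full-image context is obtained by applying the CCA module recurrently; the paper uses R = 2 loops and calls the result RCCA. A minimal sketch of such a recurrent wrapper is given below. The class name RCCA, the recurrence parameter, and the weight-sharing choice (one CCA module reused across loops, matching the paper's recurrent formulation as I read it) are illustrative, so treat this as a sketch rather than the reference implementation:

class RCCA(nn.Module):
    """Recurrent Criss-Cross Attention: apply one shared CCA module R times (sketch)."""

    def __init__(self, in_dim, recurrence=2):
        super(RCCA, self).__init__()
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(in_dim)  # weights shared across loops

    def forward(self, x):
        for _ in range(self.recurrence):
            # After two passes, every pixel has indirectly attended to the whole image
            x = self.cca(x)
        return x


# Example: two criss-cross passes give each pixel full-image context
rcca = RCCA(in_dim=8, recurrence=2)
print(rcca(torch.randn(2, 8, 64, 64)).shape)  # torch.Size([2, 8, 64, 64])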

A schematic of the CCNet architecture, built with the CCA attention module for semantic segmentation, is shown below:
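
For reference, a minimal sketch of how such a segmentation head could be wired is given below, reusing the RCCA sketch above. The class name CCHead, the layer widths, and the exact conv/BN stack are illustrative assumptions rather than the paper's exact configuration; the overall flow (reduce backbone channels, apply recurrent attention, concatenate with the reduced feature, classify per pixel, upsample) follows the CCNet design:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CCHead(nn.Module):
    """Illustrative CCNet-style head: reduce -> RCCA -> concat -> classify (sketch)."""

    def __init__(self, backbone_channels=2048, inter_channels=512, num_classes=19, recurrence=2):
        super(CCHead, self).__init__()
        # Reduce backbone channels before attention (assumed widths)
        self.reduce = nn.Sequential(
            nn.Conv2d(backbone_channels, inter_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        self.rcca = RCCA(inter_channels, recurrence=recurrence)
        # Fuse attended features with the reduced input, then classify per pixel
        self.classifier = nn.Sequential(
            nn.Conv2d(inter_channels * 2, inter_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, num_classes, kernel_size=1),
        )

    def forward(self, feat, out_size):
        reduced = self.reduce(feat)
        attended = self.rcca(reduced)
        logits = self.classifier(torch.cat([reduced, attended], dim=1))
        # Upsample coarse logits to the input resolution
        return F.interpolate(logits, size=out_size, mode='bilinear', align_corners=False)


# Example: a (1, 2048, 32, 32) backbone feature map -> (1, 19, 256, 256) logits
head = CCHead()
print(head(torch.randn(1, 2048, 32, 32), out_size=(256, 256)).shape)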