A Criss-Cross Network (CCNet) for Capturing Full-Image Contextual Information Effectively and Efficiently

The paper "CCNet: Criss-Cross Attention for Semantic Segmentation" proposes a criss-cross network (CCNet) that captures full-image contextual information in an effective and efficient way. Concretely, for each pixel, a novel criss-cross attention module harvests the contextual information of all pixels on its criss-cross path, i.e. its row and its column. With a further recurrent operation, each pixel can ultimately capture dependencies over the whole image: after two passes, information from any pixel can reach any other pixel through an intermediate point shared by their rows and columns. In addition, the paper proposes a category consistent loss that forces the criss-cross attention module to produce more discriminative features.
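
The efficiency gain is easy to quantify: one pass of full non-local attention computes (H x W) x (H x W) pairwise affinities on an H x W feature map, while one pass of criss-cross attention computes only (H x W) x (H + W - 1). The snippet below is a back-of-the-envelope illustration of that ratio; the feature-map size is an arbitrary example, not a value from the paper.

# Affinity counts for one attention pass on an H x W feature map.
# The map size below is an arbitrary illustration, not from the paper.
H, W = 96, 96
non_local = (H * W) ** 2               # full non-local: every pixel attends to every pixel
criss_cross = (H * W) * (H + W - 1)    # criss-cross: every pixel attends to its row and column
print(non_local // criss_cross)        # ~48x fewer affinities at this resolution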

[Figure: structure of the 2D CCA module]

[Figure: structure of the 3D CCA module]

[Figure: an application example of CCA]

The PyTorch implementation of the CCA module is as follows:

import torch
import torch.nn as nn


def INF(B, H, W):
    # (B*W, H, H) mask with -inf on the diagonal: it suppresses each pixel's
    # duplicate self-affinity in the vertical branch (the pixel itself already
    # appears in the horizontal branch).
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module"""

    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=3)  # normalize over the H+W affinities of each pixel
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight, starts at 0

    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        # Project to queries/keys (C/8 channels) and values (C channels).
        proj_query = self.query_conv(x)
        # Fold width into the batch dimension for the vertical (column) branch ...
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        # ... and height into the batch dimension for the horizontal (row) branch.
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        # Affinities along each column (with the -inf diagonal mask) and along each row.
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        # Joint softmax over the H column affinities and W row affinities of each pixel.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        # Aggregate values along columns and rows, then fuse with a residual connection.
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x


if __name__ == '__main__':
    input = torch.randn(6, 8, 256, 256)
    cca = CrissCrossAttention(in_dim=8)
    output = cca(input)
    print(output.shape)  # torch.Size([6, 8, 256, 256])
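
To realize the recurrent operation described above, the same module is simply applied R times with shared weights (the paper uses R = 2, which is enough for every pixel to receive information from the whole image). The wrapper below is a minimal sketch of that loop, reusing the CrissCrossAttention class defined above; the name RecurrentCCA and its interface are illustrative, not the official RCCAModule.

class RecurrentCCA(nn.Module):
    """Applies one CrissCrossAttention module R times with shared weights."""

    def __init__(self, in_dim, recurrence=2):
        super(RecurrentCCA, self).__init__()
        self.cca = CrissCrossAttention(in_dim)
        self.recurrence = recurrence

    def forward(self, x):
        for _ in range(self.recurrence):
            x = self.cca(x)  # after 2 passes, every pixel has seen the whole image
        return x


rcca = RecurrentCCA(in_dim=8, recurrence=2)
out = rcca(torch.randn(2, 8, 64, 64))
print(out.shape)  # torch.Size([2, 8, 64, 64])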

[Figure: schematic of the CCNet architecture for semantic segmentation, built with the CCA attention module]
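
For reference, a CCNet-style segmentation head roughly follows this pattern: a 3x3 convolution reduces the backbone features, the CCA module is applied recurrently, a second 3x3 convolution follows, and the attended features are concatenated with the input before the classifier. The sketch below loosely follows the RCCAModule in the official CCNet repository, but the class name RCCAHead and the channel sizes (2048-d ResNet-101 features, 512 intermediate channels, 19 Cityscapes classes) are illustrative assumptions.

class RCCAHead(nn.Module):
    """Sketch of a CCNet-style segmentation head: reduce, attend recurrently,
    fuse with the input features, then classify. Channel sizes are illustrative."""

    def __init__(self, in_channels=2048, inter_channels=512, num_classes=19, recurrence=2):
        super(RCCAHead, self).__init__()
        self.conva = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.cca = CrissCrossAttention(inter_channels)
        self.convb = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels + inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(inter_channels, num_classes, kernel_size=1))
        self.recurrence = recurrence

    def forward(self, x):
        out = self.conva(x)
        for _ in range(self.recurrence):
            out = self.cca(out)
        out = self.convb(out)
        # Concatenate backbone and attended features before classifying.
        return self.classifier(torch.cat([x, out], 1))


head = RCCAHead()
logits = head(torch.randn(1, 2048, 48, 48))
print(logits.shape)  # torch.Size([1, 19, 48, 48])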
