当前位置:   article > 正文

【计算机视觉】详解 Non-local 与 SENet、CBAM 模块融合:GCNet、DANet (视觉注意力机制 (三))_non-local 模块

non-local 模块

绪论

视觉注意力机制 (一) 阐述了视觉应用中的 Self-attention 机制及其应用 —— Non-local 网络模块;视觉注意力机制 (二)  主要分析了视觉注意力机制在分类网络中的应用 —— SENet、CBAM、SKNet 。它们构成了视觉注意力机制中的基本模块,本节将主要介绍融合 Non-local 模块和 SENet 模块的全局上下文建模网络(Global Context Network,简称 GCNet),以及 Non-local 模块与 CBAM 模块融合变形在语义分割中的应用 —— 双重注意力网络 DANet

前面的文章有详细讨论 Non-local 模块,Non-local Network(NLNet)使用自注意力机制来建模远程依赖。对于每个查询点(query position),NLNet 首先计算查询点与所有点之间的成对关系以得到注意力图,然后通过加权和的方式聚合所有点的特征,从而得到与此查询点相关的全局特征,最终再分别将全局特征加到每个查询点的特征中,完成远程依赖的建模过程。但是,NLNet存在着计算量大的问题。

SENet 模块中提出的 SE block 是使用全局上下文对不同通道进行权值重标定,对通道依赖进行调整。但是采用这种方法,并没有充分利用全局上下文信息

CBAM 将注意力过程分为两个独立的部分,通道注意力模块 (look what) 空间注意力模块 (look where)。不仅可以节约参数和计算力,而且保证了其可以作为即插即用的模块集成到现有的网络架构中去。但是 CBAM 需要手工设计 pooling,多层感知器等复杂操作。

下面,将分别介绍 GCNet 和 DANet


一、Global Context Network (GCNet)

论文地址:https://arxiv.org/search/?query=GCNet&searchtype=all&source=header

代码地址:https://github.com/xvjiarui/GCNet

为了捕获长距离依赖关系,产生了两类方法:

  • 采用自注意力机制来建模query对的关系。
  • 对query-independent(可以理解为无query依赖)的全局上下文建模。

NLNet 就是采用自注意力机制来建模像素对关系。然而NLNet对于每一个位置学习不受位置依赖的 attention map,造成了大量的计算浪费。

SENet 用全局上下文对不同通道进行权值重标定,来调整通道依赖。然而,采用权值重标定的特征融合,不能充分利用全局上下文。

GCNet 作者对 NLNet 进行试验,选择COCO数据集中的6幅图,对于不同的查询点(query point)分别对 Attention maps进行可视化,得到以下结果:

可以看出,对于不同的查询点,其attention map是几乎一致的,这说明NLNet学习到的是独立于查询的依赖(query-independent dependency),这说明虽然NLNet想要对每一个位置进行特定的全局上下文计算,但是可视化结果以及实验数据证明,non-local network 的全局上下文在不同位置几乎是相同的,这表明学习到了无位置依赖的全局上下文。

基于以上发现,作者希望能够减少不必要的计算量,降低计算,并结合SENet设计,提出了GCNet融合了两者的优点,既能够有用 NLNet 的全局上下文建模能力,又能够像 SENet 一样轻量

简化的 Non-local 模块

preview

作者通过计算一个全局的 attention map 来简化 non-local block,并且对所有位置共享这个全局 attention map。简化版的non-local block 定义为: 

                       

为了进一步减少简化版 non-local block 的计算量,将 Wv 移到 attention pooling 的外面,表示为:

                        

不同于原始的 non-local block,简化版 non-local block 的第二项是不受位置依赖的,所有位置共享这一项。因此,作者直接将全局上下文建模为所有位置特征的加权平均值,然后聚集全局上下文特征到每个位置的特征上

简化版的 non-local block 可以抽象为3个步骤:

  • 全局 attention pooling:采用 1x1 卷积 Wk 和 softmax 函数来获取 attention 权值,然后执行 attention pooling 来获得全局上下文特征。
  • 特征转换:采用 1x1 卷积 Wv 。
  • 特征聚合:采用相加操作将全局上下文特征聚合到每个位置的特征上。

SE 模块

前面讲过,SE block如下图所示,也可以抽象成3个步骤:

  • 全局平均池化用于上下文建模 (即squeeze operation)。
  • bottleneck transform 用于计算每个通道的重要程度 (即excitation operation)。
  • rescaling function 用于通道特征重标定 (即element-wise multiplication)。

GCNet 模块

作者提出了一种新的全局上下文建模框架,global context block(简写GCNet),即能够像Non-local block一样建立有效的长距离依赖,又能够像SE block一样省计算量。

GC block的3个步骤为:

  1. global attention pooling用于上下文建模
  2. bottleneck transform来捕获通道间依赖
  3. broadcast element-wise addition用于特征融合

在简化版的 non-local block 中,transform 模块有大量的参数。为了获得 SE block 轻量的优点,1x1 卷积用 bottleneck transform 模块来取代,能够显著的降低参数量(其中r是降低率)。因为两层 bottleneck transform 增加了优化难度,所以在 ReLU前面增加一个 layer normalization 层(降低优化难度且作为正则提高了泛化性)。

GC block 在 ResNet 中的使用位置是每两个 Stage 之间的连接部分,下边是 GC block 的官方实现 (基于 mmdetection 进行修改):

  1. import torch
  2. from torch import nn
  3. class ContextBlock(nn.Module):
  4. def __init__(self,inplanes,ratio,pooling_type='att',
  5. fusion_types=('channel_add', )):
  6. super(ContextBlock, self).__init__()
  7. valid_fusion_types = ['channel_add', 'channel_mul']
  8. assert pooling_type in ['avg', 'att']
  9. assert isinstance(fusion_types, (list, tuple))
  10. assert all([f in valid_fusion_types for f in fusion_types])
  11. assert len(fusion_types) > 0, 'at least one fusion should be used'
  12. self.inplanes = inplanes
  13. self.ratio = ratio
  14. self.planes = int(inplanes * ratio)
  15. self.pooling_type = pooling_type
  16. self.fusion_types = fusion_types
  17. if pooling_type == 'att':
  18. self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
  19. self.softmax = nn.Softmax(dim=2)
  20. else:
  21. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  22. if 'channel_add' in fusion_types:
  23. self.channel_add_conv = nn.Sequential(
  24. nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
  25. nn.LayerNorm([self.planes, 1, 1]),
  26. nn.ReLU(inplace=True), # yapf: disable
  27. nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
  28. else:
  29. self.channel_add_conv = None
  30. if 'channel_mul' in fusion_types:
  31. self.channel_mul_conv = nn.Sequential(
  32. nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
  33. nn.LayerNorm([self.planes, 1, 1]),
  34. nn.ReLU(inplace=True), # yapf: disable
  35. nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
  36. else:
  37. self.channel_mul_conv = None
  38. def spatial_pool(self, x):
  39. batch, channel, height, width = x.size()
  40. if self.pooling_type == 'att':
  41. input_x = x
  42. # [N, C, H * W]
  43. input_x = input_x.view(batch, channel, height * width)
  44. # [N, 1, C, H * W]
  45. input_x = input_x.unsqueeze(1)
  46. # [N, 1, H, W]
  47. context_mask = self.conv_mask(x)
  48. # [N, 1, H * W]
  49. context_mask = context_mask.view(batch, 1, height * width)
  50. # [N, 1, H * W]
  51. context_mask = self.softmax(context_mask)
  52. # [N, 1, H * W, 1]
  53. context_mask = context_mask.unsqueeze(-1)
  54. # [N, 1, C, 1]
  55. context = torch.matmul(input_x, context_mask)
  56. # [N, C, 1, 1]
  57. context = context.view(batch, channel, 1, 1)
  58. else:
  59. # [N, C, 1, 1]
  60. context = self.avg_pool(x)
  61. return context
  62. def forward(self, x):
  63. # [N, C, 1, 1]
  64. context = self.spatial_pool(x)
  65. out = x
  66. if self.channel_mul_conv is not None:
  67. # [N, C, 1, 1]
  68. channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
  69. out = out * channel_mul_term
  70. if self.channel_add_conv is not None:
  71. # [N, C, 1, 1]
  72. channel_add_term = self.channel_add_conv(context)
  73. out = out + channel_add_term
  74. return out
  75. if __name__ == "__main__":
  76. in_tensor = torch.ones((12, 64, 128, 128))
  77. cb = ContextBlock(inplanes=64, ratio=1./16.,pooling_type='att')
  78. out_tensor = cb(in_tensor)
  79. print(in_tensor.shape)
  80. print(out_tensor.shape)

二、Dual Attention Network for Scene Segmentation(DANet)

论文地址:https://arxiv.org/pdf/1809.02983.pdf​arxiv.org

代码地址:https://github.com/junfu1115/DANet​github.com

DANet 是一种经典的应用self-Attention的网络,它引入了一种自注意力机制来分别捕获空间维度和通道维度中的特征依赖关系。

场景分割需要预测出图像中的像素点属于某一目标类或场景类,其图像场景的复杂多样(光照,视角,尺度,遮挡等)对于场景的理解和像素点的判别造成很大困难。

主流场景分割方法大致可分为以下两种类型:一是通过使用多尺度特征融合的方式增强特别的表达,例如空间金字塔结构 (PSP,ASPP) 或者高层浅层特征融合 (RefineNet)。但是这些方式没有考虑到不同特征之间的关联依赖,而这对于场景的理解确实十分重要。另一是利用 RNN 网络构建特征长范围的特征关联,但这种关联往往受限于 RNN 的 long-term memorization。

双重注意网络(DANet)来自适应地集成局部特征和全局依赖。在传统的扩张 FCN 之上附加两种类型的注意力模块,分别模拟空间和通道维度中的语义相互依赖性。

preview

从其结构图中可以看到,它由两个并列的 attention module 组成,第一个得到的是特征图中任意两个位置的依赖关系,称为Position Attention Module(PAM);第二个是任意两个通道间的依赖关系,称为 Channel Attention Module(CAM)。

从其具体的模块中来看,PAM中 的 attention_map 的大小为 B×(W×H)×(W×H),而 CAM 中的 attention_map 大小为B×C×C,这就是 PAM 与 CAM 的区别,他们所代表的一个是任意两个位置之间的依赖关系,一个代表的是任意两个通道之间的依赖关系。

  • 位置注意力模块(PAM)通过所有位置处的特征的加权和来选择性地聚合每个位置的特征。无论距离如何,类似的特征都将彼此相关。
  • 通道注意力模块(CAM)通过整合所有通道映射之间的相关特征来选择性地强调存在相互依赖的通道映射。
  • 将两个注意模块的输出相加以进一步改进特征表示,这有助于更精确的分割结果。

位置注意力模块(PAM)

问题:传统FCNs生成的特征会导致对物体的错误分类。

解决:引入位置注意模块在局部特征上建立丰富的上下文关系,将更广泛的上下文信息编码为局部特征,进而增强他们的表示能力。

preview

位置注意力模块旨在利用任意两点特征之间的关联,来相互增强各自特征的表达

具体来说,首先计算出任意两点特征之间关联强度矩阵,即原始特征 A 经过卷积降维获得特征 B 和特征 C,然后改变特征维度 B 为 ((HxW)xC') 和 C 为 (C'x(HxW)) 然后矩阵乘积获得任意两点特征之间的关联强度矩 ((HxW)x(HxW))。然后经过 softmax 操作归一化获得每个位置对其他位置的 attention 图 S, 其中越相似的两点特征之间,其响应值越大。接着将 attention 图中响应值作为加权对特征 D 进行加权融合,这样对于各个位置的点,其通过 attention 图在全局空间中的融合相似特征。

通道注意力模块(CAM)

问题:每个high level特征的通道图都可以看作是一个特定于类的响应,通过挖掘通道图之间的相互依赖关系,可以突出相互依赖的特征图,提高特定语义的特征表示。

解决:建立一个通道注意力模块来显式地建模通道之间的依赖关系。

preview

通道注意力模块旨在通过建模通道之间的关联,增强通道下特定语义响应能力。

具体过程与位置注意力模块相似,不同的是在获得特征注意力图 X 时,是将任意两个通道特征进行维度变换和矩阵乘积,获得任意两个通道的关联强度,然后同样经过 softmax 操作获得的通道间的 attention 图。最后通过通道之间的 attention 图加权进行融合,使得各个通道之间能产生全局的关联,获得更强的语义响应的特征。

为了进一步获得全局依赖关系的特征,将两个模块的输出结果进行相加融合,获得最终的特征用于像素点的分类。

总的来说,DANet 网络主要思想是 CBAM 和 non-local 的融合变形。把 deep feature map 进行 spatial-wise self-attention,同时也进行 channel-wise self-attetnion,最后将两个结果进行 element-wise sum 融合。

在 CBAM 分别进行空间和通道 self-attention 的思想上,直接使用了 non-local 的自相关矩阵 Matmul 的形式进行运算,避免了 CBAM 手工设计 pooling,多层感知器等复杂操作。


文献来源:https://zhuanlan.zhihu.com/p/111143631

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/855357
推荐阅读
相关标签
  

闽ICP备14008679号