当前位置:   article > 正文

Attention Guided Global Enhancement and LocalRefinement Network for Semantic Segmentation(结合代码)_attention guided global enhancement and local refi

attention guided global enhancement and local refinement network for semanti

摘要:

编码器-解码器结构作为轻量级的语义分割网络被广泛应用。然而,与设计良好的fcn模型相比,它的性能有限,存在两个主要问题。首先,在解码器中常用的上采样方法,如插值和反卷积,会受到局部接收场的影响,无法对全局上下文进行编码。其次,由于早期编码器层语义概念的不足,低级特征可能通过跳跃式连接给网络解码器带来噪声。为了解决这些问题,提出了一种全局增强方法,从高级特征映射中聚合全局信息,并自适应地将其分布到不同的解码器层,缓解了上采样过程中全局上下文的不足。此外,我们还开发了一个局部细化模块,利用解码器特征作为语义导向,在两者融合之前细化有噪声的编码器特征(解码器特征和编码器特征)。然后,将两种方法集成到一个上下文融合块中,在此基础上,精心设计了一种新的注意引导的全局增强和局部细化网络(AGLN)。

介绍

目前大多数的语义分割方法都是基于完全卷积网络[4]的。然而,连续的下采样操作,如池化和跨卷积,导致原始输入图像的空间细节显著下降。为了缓解这一问题,大多数最先进的方法[5]-[13]放弃了一些降采样操作以保持相对高分辨率的特征映射,并利用扩张卷积来扩大接收域,称为基于扩张fcn的模型。与此相反,另一种方法是在分类网络的顶部精心设计一个解码器(即编码器),逐步重建一个高分辨率的分割图,形成一个“u形”的编码器- 解码器架构[14]-[16]。得益于低分辨率的中间特征映射,编码器-解码器网络通常比广泛使用的扩张fcn模型需要更低的计算和存储成本。然而,编码器-解码器模型的性能受到两个主要限制。

首先,现有解码器主要利用双线性插值或反卷积对高/低分辨率特征图进行上采样,以匹配像素级的监督。然而,对[17]、[18]的研究表明,这些上采样方法对像素级预测的准确恢复能力有限。具体来说,插值和反褶积都只考虑有限的邻域,不能捕捉远距离的上下文信息,这意味着上采样特征图的每个位置的特征表示都是从有限的接收域恢复的。因此,在上采样过程中全局上下文的缺失导致了大规模对象的常见误分类。(一句话:上采样的方法不行,会导致大目标分类出现错误)

此外,编码器层的低级特征通过跳过连接传递给解码器。然而,这些编码器特征映射需要进一步细化,因为它们不能提供足够细粒度的细节来实现准确的分割边界和小对象[21],[22]的识别。具体地说,由于接受域有限,早期的编码器卷积(即滑动窗口)遍历图像的所有局部区域以捕获局部上下文。这些上下文对于后面卷积层中的语义概念的形成可能是至关重要的。然而,大量混乱的局部表示(即噪声)不利于语义边界的识别。一些编码器解码器网络通过简单的添加或拼接操作[14]、[20]、[23],将低级和高级特征结合起来,淹没在相当大的噪声中,进而导致不同对象之间的混淆以及对小目标的忽略。 (一句话:简单的跳跃连接不能改善网络对小目标边界的分割)

为了应对这些挑战,本文提出了一种“注意力引导的全局增强和局部细化网络”(AGLN)框架。针对特征上采样过程中全局信息不足的问题,提出一种全局增强方法,首先从高级特征中选择性地聚合全局上下文,然后自适应地将其分布到上采样特征映射中。通过这种方式,将全局类别线索聚合并传递到解码器的各个阶段,弥补了上采样过程中全局上下文的不足,生成全局增强解码器特征图。此外,利用全局增强方法的输出作为语义引导,设计了一个局部细化模块,从编码器中仔细细化低级特征,滤除噪声,生成更多信息的局部细节。然后,将这两种方法合并到FPN基线中。所得到的AGLN可以提供全局增强的解码器特征和局部精细的空间细节,实现在不同解码器层中更有效的上下文融合。

我们的主要贡献可概括如下:

  1. 我们提出了一种全局增强方法,从高级特征中聚合全局上下文,并将其传递到不同的解码器层,以弥补上采样过程中全局信息的缺失。

  2. 我们引入了一个局部细化模块,提供语义引导来细化编码器特征,在低、高层上下文融合前过滤掉噪声。

  3. 我们提出的AGLN将上述两种方法合并到FPN基线中。大量实验表明,AGLN在PASCAL Context (56.23% mIOU)上取得了最先进的结果,在ADE20K (45.38% mIOU)和PASCAL VOC 2012 (84.9% mIOU)上取得了具有竞争力的结果。

方法:(这里结合代码更易理解,或者直接看代码)

在本节中,我们首先说明编码器-解码器分割模型的主要缺点,然后提出两个相应的解决方案来克服这些缺点。具体来说,一方面,提出了一种由语义聚合块(SAB)语义分布模块(SDM)组成的全局增强方法,以弥补上采样过程中不具备的全局特征;另一方面,设计了局部细化模块(LRM),从低阶编码器特征中滤除噪声,生成更多信息丰富的局部细节。在此基础上,构建了一种新的AGLN,在译码器进行特征融合之前增强全局上下文和细化局部纹理,与原始的编码器-译码器基准相比性能得到了提高。

a.解码器特征的全局增强:

1)在上采样过程中缺乏全局特征:本研究旨在突破现有上采样方法的局限性,寻求一种改进高分辨率分割图重建的方法。为此,一个理想的上采样过程增强方法应该能够(1)聚合全局空间中的上下文信息,(2)自适应地更新特征地图每个特定位置的表示,(3)保持计算效率。为此,提出了全局增强方法。

2)全局增强:在我们的方法中,保留了原有的基线上采样算子(即双线性插值)。此外,我们建议从低分辨率/高分辨率特征中聚合全局特征,并将它们附加到上采样解码器特征映射中,从而实现上采样过程中全局上下文的恢复。

图中三个block建议看代码,论文阅读起来并不清晰,其实就是利用了注意力机制。

如果不了解自注意力机制相关知识,建议先进行了解,可以看我的另一个博客(附带代码)

Semantic Aggregation Block(SAB语义聚合块):

  1. class SAB(nn.Module):
  2. """
  3. Semantic Aggregation Block:
  4. Aggregate global semantic descriptors from the encoder output.
  5. Params:
  6. c_in: input channels, same as fpn_dim(256)
  7. c_feat: feature channels, C in the paper (default as 256).
  8. c_atten: number of semantic descriptors, N in the paper (1, 64, 128, 256).
  9. """
  10. def __init__(self, c_in, c_feat, c_atten):
  11. super(SAB, self).__init__()
  12. self.c_feat = c_feat
  13. self.c_atten = c_atten
  14. self.conv_feat = nn.Conv2d(c_in, c_feat, kernel_size=1)
  15. self.conv_atten = nn.Conv2d(c_in, c_atten, kernel_size=1)
  16. def forward(self, input: torch.Tensor):
  17. b, c, h, w = input.size()
  18. feat = self.conv_feat(input).view(b, self.c_feat, -1) # feature map
  19. atten = self.conv_atten(input).view(b, self.c_atten, -1) # attention map
  20. atten = F.softmax(atten, dim=-1)
  21. descriptors = torch.bmm(feat, atten.permute(0, 2, 1)) # (c_feat, c_atten)
  22. return descriptors

这里input是图中的X,分别进行两种卷积最后输出D

Semantic Distribution Module: 

  1. class SDM(nn.Module):
  2. """
  3. Semantic Distribution Module:
  4. Distribute global semantic descriptors to each stage of decoder.
  5. Params:
  6. c_atten: number of semantic descriptors, N in the paper.
  7. c_de: decoder channels
  8. """
  9. def __init__(self, c_atten, c_de):
  10. super(SDM, self).__init__()
  11. self.c_atten = c_atten
  12. self.conv_de = nn.Conv2d(c_de, c_atten, kernel_size=1)
  13. self.out_conv = nn.Conv2d(c_de, c_de, kernel_size=1)
  14. def forward(self, descriptors: torch.Tensor, input_de: torch.Tensor):
  15. b, c, h, w = input_de.size()
  16. atten_vectors = F.softmax(self.conv_de(input_de), dim=1)
  17. output = descriptors.matmul(atten_vectors.view(b, self.c_atten, -1)).view(b, -1, h, w)
  18. return self.out_conv(output)

descriptors是上一个模块的输出D,input_de是A。进行了如图所示的操作,但似乎少了Element-wise Sum操作。

B.编码器特征的局部细化:

1)背景融合中的噪声编码器特征:在上下文融合中,由浅层特征层编码的低级视觉元素是准确预测目标边界和细节的关键。然而,由于早期编码器层的接收域有限,主要特征通常从整个图像中捕获大量低层纹理,模糊了目标和背景之间的边界。在上下文融合中,不必要的纹理对分割边界的识别贡献较小,严重影响对精细物体的分类,可以视为噪声。与低层次的编码器特征相比,解码器的表示编码了更多的语义信息,对不同的对象有相对明显的区分,甚至可以看作是粗分割图。因此,利用解码器的特征作为语义引导来捕捉有价值的空间细节并过滤噪声是合理的。在下一节中,我们将从信道维和空间维两个方面详细阐述编码器噪声特征的问题并提出相应的解决方案。

2)局部细化:局部细化模块被设计用于在通道(通道重采样)和空间(空间门控)维度上细化编码器特征。

        通道重采样:大量的低层特征通道不提供对分割有益的信息,需要在上下文融合前进行过滤。具体来说,从编码器特征的可视化(图5(b)和图5(d)的前两行)可以观察到,大部分通道(由绿色框标记)没有边界或分割目标的大致形状。这些信道对于物体边界的识别是无用的,可以看作是纯噪声。只有少数通道描述了清晰的结构细节,真正有助于分割。为了滤除编码器特征的噪声信道,采用通道注意机制在通道维度上对编码器特征映射进行重采样,即通道重采样。具体来说,建立一个注意模块来建模噪声编码器特征和语义丰富的解码器特征之间的信道依赖关系。在我们的设计中,我们强调了编码器特征与解码器特征高度相关的通道,从而提高了特定语义概念的表示。相反,编码器特征中不区分不同对象的通道将有很高的概率被过滤掉。实现细节如图2(a)的红框(“CR”表示通道重采样)所示。给定编码器特征 B ∈ R^(C×kH×kW), (k = 2, 4, 8),增强的解码器特征图E ∈ R^(C×kH×kW),(语义分布模块的输出),通过下面的公式,实现通道注意力以获得重采样的编码器特征图F,F = B × softmax(B × E^T),其中应用softmax算子来获得通道注意图。

        空间门控:低级特征通常在没有类别意识的情况下编码局部细节和纹理。一方面,目标区域外过多的纹理表示会模糊对象边界。另一方面,对象内的细节特征可能导致其他区域的错误分类。我们的见解是利用语义描述符图M作为空间门控图,从感兴趣的区域“裁剪”有价值的细节,从目标区域中丢弃无用甚至有害的纹理。选择语义描述符图是因为它具有很强的语义一致性,与解码器特征图相比,它更适合于生成完整的语义区域。如图2(a)的红框(“SG”表示空间门控)所示,在语义描述符图M上执行sigmoid层,以获得空间门控图,然后将该空间门控图与重采样的编码器特征逐元素相乘。

  1. class LRM(nn.Module):
  2. """
  3. Local Refinement Module: including channel resampling and spatial gating.
  4. Params:
  5. c_en: encoder channels
  6. c_de: decoder channels
  7. """
  8. def __init__(self, c_en, c_de):
  9. super(LRM, self).__init__()
  10. self.c_en = c_en
  11. self.c_de = c_de
  12. def forward(self, input_en: torch.Tensor, input_de: torch.Tensor, gate_map):
  13. b, c, h, w = input_de.size()
  14. input_en = input_en.view(b, self.c_en, -1)
  15. # Channel Resampling
  16. energy = input_de.view(b, self.c_de, -1).matmul(input_en.transpose(-1, -2))
  17. energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(
  18. energy) - energy # Prevent loss divergence during training
  19. channel_attention_map = torch.softmax(energy_new, dim=-1)
  20. input_en = channel_attention_map.matmul(input_en).view(b, -1, h, w) # channel_attention_feat
  21. # Spatial Gating
  22. gate_map = torch.sigmoid(gate_map)
  23. input_en = input_en.mul(gate_map)
  24. return input_en
  1. class CFB(nn.Module):
  2. """
  3. Context Fusion Block: including SDM and LRM.
  4. Params:
  5. c_atten: number of semantic descriptors, N in the paper.
  6. """
  7. def __init__(self, fpn_dim=256, c_atten=256, norm_layer=None, ):
  8. super(CFB, self).__init__()
  9. self.sdm = SDM(c_atten, fpn_dim)
  10. self.lrm = LRM(fpn_dim, fpn_dim)
  11. self.conv_fusion = nn.Sequential(
  12. nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1, bias=False),
  13. # nn.Conv2d(3 * fpn_dim, fpn_dim, kernel_size=3, padding=1, bias=False),
  14. # DepthwiseConv(3 * fpn_dim, fpn_dim),
  15. norm_layer(fpn_dim),
  16. nn.ReLU(inplace=True),
  17. )
  18. self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
  19. self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
  20. self.beta = nn.Parameter(torch.ones(1), requires_grad=True)
  21. def forward(self, input_en: torch.Tensor, input_de: torch.Tensor, global_descripitors: torch.Tensor):
  22. feat_global = self.sdm(global_descripitors, input_de)
  23. feat_local = self.gamma * self.lrm(input_en, input_de, feat_global) + input_en
  24. # add fusion
  25. return self.conv_fusion(input_de + self.alpha * feat_global + self.beta * feat_local)
  26. # concat fusion
  27. # return self.conv_fusion(torch.cat((input_de, self.beta * feat_global, self.gamma * feat_local), dim=1))

 Context Fusion Block

通过代码可以比较容易地理解作者的网络结构和思想。后续的实验暂时不更新。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/366882
推荐阅读
相关标签
  

闽ICP备14008679号