
DBNet++ (TPAMI): Principle and Code Walkthrough


paper: Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion

code1: https://github.com/MhLiao/DB

code2: https://github.com/open-mmlab/mmocr

Contributions

This paper is an improvement over DBNet; for an introduction to DBNet, see the earlier post on the differentiable-binarization scene text detector DBNet (principle and code walkthrough). The new contribution is an Adaptive Scale Fusion (ASF) module that adaptively fuses multi-scale features. Plugging ASF into the segmentation network significantly strengthens its ability to detect text instances of different scales.

Method

The complete architecture of DBNet++ is shown in the figure below.

In it, an ASF module is inserted between the multi-level outputs of the FPN and the final prediction feature map.

The complete structure of ASF is shown in the figure below.

The FPN outputs are $X \in \mathbb{R}^{N \times C \times H \times W} = \{X_i\}_{i=0}^{N-1}$, where $N = 4$ is the number of FPN output feature maps at different scales, all interpolated to the same spatial size. First, $X$ is concatenated along the channel dimension and passed through a $3 \times 3$ convolution layer to obtain the intermediate feature $S \in \mathbb{R}^{C \times H \times W}$. Then $S$ is fed through a spatial attention module to obtain the attention weights $A \in \mathbb{R}^{N \times H \times W}$. Finally, $A$ is split evenly into $N$ parts along the channel dimension, and each part is multiplied with the corresponding feature map to produce the final fused feature $F \in \mathbb{R}^{N \times C \times H \times W}$.
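As a rough sketch of this data flow (illustrative only: fuse_conv, asf_fuse, and attention_module are placeholder names, not mmocr's; the actual attention module is analyzed in the code section below):

import torch
import torch.nn as nn

N, C = 4, 64  # 4 FPN scales, 64 channels per scale (as in the walkthrough below)
fuse_conv = nn.Conv2d(N * C, N * C, 3, padding=1)  # the 3x3 conv after concat

def asf_fuse(features, attention_module):
    """features: list of N tensors, each (B, C, H, W), already resized to a common size."""
    concat = torch.cat(features, dim=1)  # (B, N*C, H, W)
    s = fuse_conv(concat)                # intermediate feature S
    a = attention_module(s)              # attention weights A: (B, N, H, W)
    # split A along channels and weight each scale's feature map
    weighted = [a[:, i:i + 1] * features[i] for i in range(N)]
    return torch.cat(weighted, dim=1)    # fused feature F: (B, N*C, H, W)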

The complete scale attention process is defined as follows.
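A reconstruction of these equations in LaTeX, based on the official spatial-attention code (a sketch; $\overline{S}$ denotes the channel-wise mean of $S$ and $\sigma$ the sigmoid function):

$$E = \sigma\big(\mathrm{Conv}_{1\times 1}(\mathrm{ReLU}(\mathrm{Conv}_{3\times 3}(\overline{S})))\big) + S$$

$$A = \sigma\big(\mathrm{Conv}_{1\times 1}(E)\big)$$

$$F = \mathrm{Concat}(A_0 \cdot X_0,\ \ldots,\ A_{N-1} \cdot X_{N-1})$$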

Code Walkthrough

Here we use the mmocr implementation as an example. Note that although the paper describes ASF as a spatial attention module, in the official implementation https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py the authors provide three different attention mechanisms: besides the spatial attention described in the paper, there are also a channel attention and a combined spatial-channel attention. MMOCR only ports the spatial-channel attention variant, ScaleChannelSpatialAttention, shown below.

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, Sequential


class ScaleChannelSpatialAttention(BaseModule):
    """Spatial Attention module in Real-Time Scene Text Detection with
    Differentiable Binarization and Adaptive Scale Fusion.

    This was partially adapted from https://github.com/MhLiao/DB

    Args:
        in_channels (int): Number of input channels.
        c_wise_channels (int): Number of channel-wise attention channels.
        out_channels (int): Number of output channels.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(
        self,
        in_channels: int,  # 256
        c_wise_channels: int,  # 64
        out_channels: int,  # 4
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type='Kaiming', layer='Conv', bias=0)
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Channel Wise
        self.channel_wise = Sequential(
            ConvModule(
                in_channels,
                c_wise_channels,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='ReLU'),
                inplace=False),
            ConvModule(
                c_wise_channels,
                in_channels,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='Sigmoid'),
                inplace=False))
        # Spatial Wise
        self.spatial_wise = Sequential(
            ConvModule(
                1,
                1,
                3,
                padding=1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='ReLU'),
                inplace=False),
            ConvModule(
                1,
                1,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='Sigmoid'),
                inplace=False))
        # Attention Wise
        self.attention_wise = ConvModule(
            in_channels,
            out_channels,
            1,
            bias=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=dict(type='Sigmoid'),
            inplace=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs (Tensor): A concat FPN feature tensor that has the shape
                of :math:`(N, C, H, W)`.

        Returns:
            Tensor: An attention map of shape :math:`(N, C_{out}, H, W)`
            where :math:`C_{out}` is ``out_channels``.
        """
        # inputs: (4,256,160,160)
        out = self.avg_pool(inputs)  # (4,256,1,1)
        out = self.channel_wise(out)  # (4,256,1,1), channel attention weights
        out = out + inputs  # broadcast add -> (4,256,160,160)
        inputs = torch.mean(out, dim=1, keepdim=True)  # (4,1,160,160)
        out = self.spatial_wise(inputs) + out  # (4,1,160,160)+(4,256,160,160)->(4,256,160,160)
        out = self.attention_wise(out)  # (4,4,160,160)
        return out

Suppose batch_size=4 and input_size=(640, 640). After upsampling, the four FPN levels yield feature maps of the same size, i.e. the list [(4,64,160,160), (4,64,160,160), (4,64,160,160), (4,64,160,160)]. These are concatenated along the channel dimension into a tensor of shape (4,256,160,160), which then passes through a 3x3 convolution layer (shape unchanged, still (4,256,160,160)) to form the input to the ASF module.
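For concreteness, a minimal usage sketch with these shapes (variable names here are illustrative, not from mmocr; assumes the ScaleChannelSpatialAttention class above is in scope):

import torch

asf_attn = ScaleChannelSpatialAttention(
    in_channels=256, c_wise_channels=64, out_channels=4)
x = torch.randn(4, 256, 160, 160)  # concatenated FPN features after the 3x3 conv
attention = asf_attn(x)
print(attention.shape)  # torch.Size([4, 4, 160, 160])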

The input first passes through global average pooling to get a (4,256,1,1) tensor. The channel attention module self.channel_wise is a two-layer convolution, conv1x1-64-ReLU followed by conv1x1-256-Sigmoid, which outputs channel attention weights of the same size; these are broadcast-added to the original input. Next, the mean is taken along the channel dimension, and the result goes through the spatial attention module self.spatial_wise, also two convolution layers, conv3x3-1-ReLU followed by conv1x1-1-Sigmoid, giving spatial attention weights that are broadcast-added back onto the feature map. Finally, self.attention_wise, a conv1x1-4-Sigmoid, produces the ASF module output of shape (4,4,160,160).
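To make the arithmetic explicit without the mmcv wrappers, here is a rough plain-PyTorch equivalent of the same forward pass (a sketch, assuming each ConvModule above reduces to conv + activation since norm_cfg=None; class and variable names are mine, not mmocr's):

import torch
import torch.nn as nn

class PlainASFAttention(nn.Module):
    """Plain-PyTorch rewrite of ScaleChannelSpatialAttention (sketch)."""

    def __init__(self, in_channels=256, c_wise_channels=64, out_channels=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_wise = nn.Sequential(
            nn.Conv2d(in_channels, c_wise_channels, 1, bias=False), nn.ReLU(),
            nn.Conv2d(c_wise_channels, in_channels, 1, bias=False), nn.Sigmoid())
        self.spatial_wise = nn.Sequential(
            nn.Conv2d(1, 1, 3, padding=1, bias=False), nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False), nn.Sigmoid())
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.Sigmoid())

    def forward(self, x):
        out = self.channel_wise(self.avg_pool(x)) + x  # channel attention + residual
        out = self.spatial_wise(out.mean(1, keepdim=True)) + out  # spatial attention + residual
        return self.attention_wise(out)  # (B, out_channels, H, W)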

The four attention maps output by the ASF module are then multiplied with the corresponding four original FPN outputs, and the weighted features are concatenated along the channel dimension to obtain the final output:

enhanced_feature = []  # collect the weighted feature map of each scale
for i, out in enumerate(outs):  # outs: the N upsampled FPN feature maps
    enhanced_feature.append(attention[:, i:i + 1] * outs[i])  # weight scale i
out = torch.cat(enhanced_feature, dim=1)  # concat along channels: (4,256,160,160)