当前位置:   article > 正文

3.deeplabv3+的深层网络结构的实现_deeplabv3网络结构

deeplabv3网络结构

         在第一篇文章中我们提到“在encoder部分,主要包括了backbone(DCNN)、ASPP两大部分”,在这里的backbone就是mobilenetv2网络结构和xception网络结构,而ASPP结构就是深层网络结构,其网络结构如下:

        ASPP网络结构的原理其实很简单,可以看博文1.deeplabv3+网络结构及原理-CSDN博客,该博文有介绍。以上网络结构里的rate表示空洞卷积核的大小,显然,该网络结构总共5层卷积处理,之后再将不同的层用concat堆叠,最后再用1x1的卷积核整合特征,转换为图片中绿色的层。

       下面深层网络结构的代码如下:

  1. #-----------------------------------------#
  2. # ASPP特征提取模块
  3. # 利用不同膨胀率的膨胀卷积进行特征提取
  4. #-----------------------------------------#
  5. class ASPP(nn.Module):
  6. def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
  7. super(ASPP, self).__init__()
  8. self.branch1 = nn.Sequential(
  9. nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
  10. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  11. nn.ReLU(inplace=True),
  12. )
  13. self.branch2 = nn.Sequential(
  14. nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
  15. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  16. nn.ReLU(inplace=True),
  17. )
  18. self.branch3 = nn.Sequential(
  19. nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
  20. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  21. nn.ReLU(inplace=True),
  22. )
  23. self.branch4 = nn.Sequential(
  24. nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
  25. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  26. nn.ReLU(inplace=True),
  27. )
  28. self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
  29. self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
  30. self.branch5_relu = nn.ReLU(inplace=True)
  31. self.conv_cat = nn.Sequential(
  32. nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
  33. nn.BatchNorm2d(dim_out, momentum=bn_mom),
  34. nn.ReLU(inplace=True),
  35. )
  36. def forward(self, x):
  37. [b, c, row, col] = x.size()
  38. # -----------------------------------------#
  39. # 一共五个分支
  40. # -----------------------------------------#
  41. conv1x1 = self.branch1(x)
  42. conv3x3_1 = self.branch2(x)
  43. conv3x3_2 = self.branch3(x)
  44. conv3x3_3 = self.branch4(x)
  45. # -----------------------------------------#
  46. # 第五个分支,全局平均池化+卷积
  47. # -----------------------------------------#
  48. global_feature = torch.mean(x, 2, True)
  49. global_feature = torch.mean(global_feature, 3, True)
  50. global_feature = self.branch5_conv(global_feature)
  51. global_feature = self.branch5_bn(global_feature)
  52. global_feature = self.branch5_relu(global_feature)
  53. global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
  54. # -----------------------------------------#
  55. # 将五个分支的内容堆叠起来
  56. # 然后1x1卷积整合特征。
  57. # -----------------------------------------#
  58. feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
  59. result = self.conv_cat(feature_cat)
  60. return result

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/678186
推荐阅读
相关标签
  

闽ICP备14008679号