当前位置:   article > 正文

解析DeepLabv3+的网络结构及代码【Pytorch版】

deeplabv3+

论文地址:https://arxiv.org/pdf/1802.02611.pdf

论文笔记:DeepLabV3+: 在DeepLabV3基础上引入了Decoder_deeplabv3+的decoder代码是哪段_叶舟的博客-CSDN博客

参考代码:https://github.com/yassouali/pytorch_segmentation/blob/master/models/deeplabv3_plus.py

目录

0、写在前面

1、网络结构

1.1、Encoder

1.2、Decoder

 2、代码解析

2.1、class DeepLab

2.2、backbone部分

2.2.1、ResNet作为backbone

2.2.2、Xception作为backbone

2.3、class ASPP

2.4、class Decoder

3、总结


0、写在前面

理解一个网络模型的最好方式,是结合其实现代码对照理解;理解一个网络的开源代码的最好方式,是结合网络结构图来一一对应着看。

DeepLabV3+的论文已经阅读完毕,阅读笔记可以参考这篇博客,现在开始从开源的代码上来详细了解网络实现的细节,主要参考了这个库中的实现。

1、网络结构

DeepLabV3+的网络结构如下图所示,主要为Encoder-Decoder结构。其中,Encoder为改进的DeepLabV3,Decoder为3+版本新提出的。

图1.  DeepLabV3+网络结构图
图1.  DeepLabV3+网络结构图

1.1、Encoder

在Encoder部分,主要包括了backbone(即:图1中的DCNN)、ASPP两大部分。

  • 其中backbone有两种网络结构:将layer4改为空洞卷积的Resnet系列、改进的Xception。从backbone出来的feature map分两部分:一部分是最后一层卷积输出的feature maps,另一部分是中间的低级特征的feature maps;backbone输出的第一部分送入ASPP模块,第二部分则送入Decoder模块。
  • ASPP模块接受backbone的第一部分输出作为输入,使用了四种不同膨胀率的空洞卷积块(包括卷积、BN、激活层)和一个全局平均池化块(包括池化、卷积、BN、激活层)得到一共五组feature maps,将其concat起来之后,经过一个1*1卷积块(包括卷积、BN、激活、dropout层),最后送入Decoder模块。

1.2、Decoder

在Decoder部分,接收来自backbone中间层的低级feature maps和来自ASPP模块的输出作为输入。

  • 首先,对低级feature maps使用1*1卷积进行通道降维,从256降到48(之所以需要降采样到48,是因为太多的通道会掩盖ASPP输出的feature maps的重要性,且实验验证48最佳);
  • 然后,对来自ASPP的feature maps进行插值上采样,得到与低级featuremaps尺寸相同的feature maps;
  • 接着,将通道降维的低级feature maps和线性插值上采样得到的feature maps使用concat拼接起来,并送入一组3*3卷积块进行处理;
  • 最后,再次进行线性插值上采样,得到与原图分辨率大小一样的预测图。

整个的网络结构即为上述几个部分,下面结合Pytorch实现的代码进行理解。

 2、代码解析

2.1、class DeepLab

该class为总的网络结构,我们从forward函数可以看出其整体的流程:输入x经过backbone得到16倍下采样的feature map1和低级feature map2;feature map1送入ASPP模块,得到结果,然后和feature map2一起送入Decoder模块;最后经过插值得到与原图大小相等的预测图。代码如下:

  1. '''
  2. -> Deeplab V3 +
  3. '''
  4. class DeepLab(BaseModel):
  5. def __init__(self, num_classes, in_channels=3, backbone='xception', pretrained=True,
  6. output_stride=16, freeze_bn=False, **_):
  7. super(DeepLab, self).__init__()
  8. assert ('xception' or 'resnet' in backbone)
  9. if 'resnet' in backbone:
  10. self.backbone = ResNet(in_channels=in_channels, output_stride=output_stride, pretrained=pretrained)
  11. low_level_channels = 256
  12. else:
  13. self.backbone = Xception(output_stride=output_stride, pretrained=pretrained)
  14. low_level_channels = 128
  15. self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)
  16. self.decoder = Decoder(low_level_channels, num_classes)
  17. if freeze_bn: self.freeze_bn()
  18. def forward(self, x):
  19. H, W = x.size(2), x.size(3)
  20. x, low_level_features = self.backbone(x)
  21. x = self.ASSP(x)
  22. x = self.decoder(x, low_level_features)
  23. x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
  24. return x
  25. # Two functions to yield the parameters of the backbone
  26. # & Decoder / ASSP to use differentiable learning rates
  27. # FIXME: in xception, we use the parameters from xception and not aligned xception
  28. # better to have higher lr for this backbone
  29. def get_backbone_params(self):
  30. return self.backbone.parameters()
  31. def get_decoder_params(self):
  32. return chain(self.ASSP.parameters(), self.decoder.parameters())
  33. def freeze_bn(self):
  34. for module in self.modules():
  35. if isinstance(module, nn.BatchNorm2d): module.eval()

 需要注意的是:如果使用ResNet系列作为backbone,中间的低级feature map输出维度为256,如果使用Xception作为backbone,中间的低级feature map维度为128。不过,不管是256还是128,最终都要在送入Decoder后降采样到48通道。

2.2、backbone部分

上文提到,backbone分ResNet系列和Xception两种。

2.2.1、ResNet作为backbone

对于ResNet系列,一共有layer0~4,共五个layer。其中,前三个layers,也即layer0~layer2不变,仅针对layer3、layer4进行了改进,将普通卷积改为了空洞卷积。如果输出步幅(输入尺寸与输出feature map尺寸之比)为8,需要改动layer3和layer4;如果输出步幅为16,则仅改动layer4:

  1. if output_stride == 16: s3, s4, d3, d4 = (2, 1, 1, 2)
  2. elif output_stride == 8: s3, s4, d3, d4 = (1, 1, 2, 4)
  3. if output_stride == 8:
  4. for n, m in self.layer3.named_modules():
  5. if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):
  6. m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)
  7. elif 'conv2' in n:
  8. m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)
  9. elif 'downsample.0' in n:
  10. m.stride = (s3, s3)
  11. for n, m in self.layer4.named_modules():
  12. if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):
  13. m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)
  14. elif 'conv2' in n:
  15. m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)
  16. elif 'downsample.0' in n:
  17. m.stride = (s4, s4)

此外,中间的低级feature maps在ResNet系列中,是layer1的输出。

2.2.2、Xception作为backbone

如果以Xception作为backbone,则需要对Xception的中间流(Middle Flow)和出口流(Exit flow)进行改动:去掉原有的池化层,并将原有的卷积层替换为带有步长的可分离卷积,但是入口流(Entry Flow)不变:

  1. # Stride for block 3 (entry flow), and the dilation rates for middle flow and exit flow
  2. if output_stride == 16: b3_s, mf_d, ef_d = 2, 1, (1, 2)
  3. if output_stride == 8: b3_s, mf_d, ef_d = 1, 2, (2, 4)
  4. # Entry Flow
  5. self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, padding=1, bias=False)
  6. self.bn1 = nn.BatchNorm2d(32)
  7. self.relu = nn.ReLU(inplace=True)
  8. self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1, bias=False)
  9. self.bn2 = nn.BatchNorm2d(64)
  10. self.block1 = Block(64, 128, stride=2, dilation=1, use_1st_relu=False)
  11. self.block2 = Block(128, 256, stride=2, dilation=1)
  12. self.block3 = Block(256, 728, stride=b3_s, dilation=1)
  13. # Middle Flow
  14. for i in range(16):
  15. exec(f'self.block{i+4} = Block(728, 728, stride=1, dilation=mf_d)')
  16. # Exit flow
  17. self.block20 = Block(728, 1024, stride=1, dilation=ef_d[0], exit_flow=True)
  18. self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=ef_d[1])
  19. self.bn3 = nn.BatchNorm2d(1536)
  20. self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=ef_d[1])
  21. self.bn4 = nn.BatchNorm2d(1536)
  22. self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=ef_d[1])
  23. self.bn5 = nn.BatchNorm2d(2048)

而中间的低级feature maps在Xception系列中,是Entry Flow中block1的输出。

2.3、class ASPP

从backbone出来的输出步幅为16的feature maps被送入了ASPP模块,在该模块中经过不同膨胀率的卷积块和一个全局信息提取块后,concat起来,最后经过一个1*1卷积块之后,即为ASPP模块的输出。

注意,这里之所以说是“块”,是因为其不单单包含一个操作,也包含了多个其他的操作,如BN、RELU、Dropout等,上文的1.1节等地方均有类似描述。

如ASPP的不同膨胀率的分支定义如下:

  1. def assp_branch(in_channels, out_channles, kernel_size, dilation):
  2. padding = 0 if kernel_size == 1 else dilation
  3. return nn.Sequential(
  4. nn.Conv2d(in_channels, out_channles, kernel_size, padding=padding, dilation=dilation, bias=False),
  5. nn.BatchNorm2d(out_channles),
  6. nn.ReLU(inplace=True))

全局信息提取块定义如下:

  1. self.avg_pool = nn.Sequential(
  2. nn.AdaptiveAvgPool2d((1, 1)),
  3. nn.Conv2d(in_channels, 256, 1, bias=False),
  4. nn.BatchNorm2d(256),
  5. nn.ReLU(inplace=True))

ASPP类定义的完整代码如下:

  1. class ASSP(nn.Module):
  2. def __init__(self, in_channels, output_stride):
  3. super(ASSP, self).__init__()
  4. assert output_stride in [8, 16], 'Only output strides of 8 or 16 are suported'
  5. if output_stride == 16: dilations = [1, 6, 12, 18]
  6. elif output_stride == 8: dilations = [1, 12, 24, 36]
  7. self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0])
  8. self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1])
  9. self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2])
  10. self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3])
  11. self.avg_pool = nn.Sequential(
  12. nn.AdaptiveAvgPool2d((1, 1)),
  13. nn.Conv2d(in_channels, 256, 1, bias=False),
  14. nn.BatchNorm2d(256),
  15. nn.ReLU(inplace=True))
  16. self.conv1 = nn.Conv2d(256*5, 256, 1, bias=False)
  17. self.bn1 = nn.BatchNorm2d(256)
  18. self.relu = nn.ReLU(inplace=True)
  19. self.dropout = nn.Dropout(0.5)
  20. initialize_weights(self)
  21. def forward(self, x):
  22. x1 = self.aspp1(x)
  23. x2 = self.aspp2(x)
  24. x3 = self.aspp3(x)
  25. x4 = self.aspp4(x)
  26. x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
  27. x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1))
  28. x = self.bn1(x)
  29. x = self.dropout(self.relu(x))
  30. return x

2.4、class Decoder

Decoder部分属于最后一部分了,其接受backbone的低级feature maps和ASPP输出的feature maps,并对其分别进行了降维、上采样,然后concat,最后经过一组3*3卷积块后输出。其类定义代码如下:

  1. class Decoder(nn.Module):
  2. def __init__(self, low_level_channels, num_classes):
  3. super(Decoder, self).__init__()
  4. self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False)
  5. self.bn1 = nn.BatchNorm2d(48)
  6. self.relu = nn.ReLU(inplace=True)
  7. # Table 2, best performance with two 3x3 convs
  8. self.output = nn.Sequential(
  9. nn.Conv2d(48+256, 256, 3, stride=1, padding=1, bias=False),
  10. nn.BatchNorm2d(256),
  11. nn.ReLU(inplace=True),
  12. nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
  13. nn.BatchNorm2d(256),
  14. nn.ReLU(inplace=True),
  15. nn.Dropout(0.1),
  16. nn.Conv2d(256, num_classes, 1, stride=1),
  17. )
  18. initialize_weights(self)
  19. def forward(self, x, low_level_features):
  20. low_level_features = self.conv1(low_level_features)
  21. low_level_features = self.relu(self.bn1(low_level_features))
  22. H, W = low_level_features.size(2), low_level_features.size(3)
  23. x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
  24. x = self.output(torch.cat((low_level_features, x), dim=1))
  25. return x

需要注意的是,该代码将最后的4倍上采样插值的操作放到Decoder外面了,这一点与论文稍有差别,但只是归属不同,效果是一样的,不影响使用。

3、总结

本文主要结合DeepLabv3+的网络结构和开源代码进行了对照解析。通过对网络结构中的每一块的作用及其代码实现中的细节进行了解,就能够对该网络模型有一个宏观和微观上的把握,理解层次也更加深入了。

对网络结构理解清晰透彻之后,就可以尝试魔改了,比如把线性插值上采样改为反卷积,是不是就可以避免转TensorRT时对上采样操作的不支持呢?这个想法后续有时间会尝试一下。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/678168
推荐阅读
相关标签
  

闽ICP备14008679号