赞
踩
在第一篇文章中我们提到“在encoder部分,主要包括了backbone(DCNN)、ASPP两大部分”,在这里的backbone就是mobilenetv2网络结构和xception网络结构,而ASPP结构就是深层网络结构,其网络结构如下:
ASPP网络结构的原理其实很简单,可以看博文1.deeplabv3+网络结构及原理-CSDN博客,该博文有介绍。以上网络结构里的rate表示空洞卷积核的大小,显然,该网络结构总共5层卷积处理,之后再将不同的层用concat堆叠,最后再用1x1的卷积核整合特征,转换为图片中绿色的层。
下面深层网络结构的代码如下:
- #-----------------------------------------#
- # ASPP特征提取模块
- # 利用不同膨胀率的膨胀卷积进行特征提取
- #-----------------------------------------#
- class ASPP(nn.Module):
- def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
- super(ASPP, self).__init__()
- self.branch1 = nn.Sequential(
- nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
- nn.BatchNorm2d(dim_out, momentum=bn_mom),
- nn.ReLU(inplace=True),
- )
- self.branch2 = nn.Sequential(
- nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
- nn.BatchNorm2d(dim_out, momentum=bn_mom),
- nn.ReLU(inplace=True),
- )
- self.branch3 = nn.Sequential(
- nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
- nn.BatchNorm2d(dim_out, momentum=bn_mom),
- nn.ReLU(inplace=True),
- )
- self.branch4 = nn.Sequential(
- nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
- nn.BatchNorm2d(dim_out, momentum=bn_mom),
- nn.ReLU(inplace=True),
- )
- self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
- self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
- self.branch5_relu = nn.ReLU(inplace=True)
-
- self.conv_cat = nn.Sequential(
- nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
- nn.BatchNorm2d(dim_out, momentum=bn_mom),
- nn.ReLU(inplace=True),
- )
-
- def forward(self, x):
- [b, c, row, col] = x.size()
- # -----------------------------------------#
- # 一共五个分支
- # -----------------------------------------#
- conv1x1 = self.branch1(x)
- conv3x3_1 = self.branch2(x)
- conv3x3_2 = self.branch3(x)
- conv3x3_3 = self.branch4(x)
- # -----------------------------------------#
- # 第五个分支,全局平均池化+卷积
- # -----------------------------------------#
- global_feature = torch.mean(x, 2, True)
- global_feature = torch.mean(global_feature, 3, True)
- global_feature = self.branch5_conv(global_feature)
- global_feature = self.branch5_bn(global_feature)
- global_feature = self.branch5_relu(global_feature)
- global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
-
- # -----------------------------------------#
- # 将五个分支的内容堆叠起来
- # 然后1x1卷积整合特征。
- # -----------------------------------------#
- feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
- result = self.conv_cat(feature_cat)
- return result
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。