赞
踩
一:FCN回顾
上一博文我们学习了FCN,有不同的特征融合版本。
至于为什么要进行特征能融合呢?由于池化操作的存在,浅层卷积视野小,具体一些,细节更加详细,越深层的视野大,图像越小,越粗粒度,细节也是越来越模糊,所以,下采样的好处是,带来了感受域的提升,同时也减少计算量,但是却忽略了很多细节,让图像变得平湖模糊,因此,作者将浅层的细节特征也进行了特征融合。
较浅的卷积层(靠前的)的感受域比较小,学习感知细节部分的能力强,较深的隐藏层 (靠后的),感受域相对较大,适合学习较为整体的、相对更宏观一些的特征。
所以在较深的卷积层上进行反卷积还原,自然会丢失很多细节特征。
于是我们会在反卷积步骤时,考虑采用一部分较浅层的反卷积信息辅助叠加,更好的优化分割结果的精度:
至于效果具体是如何呢?
作者在原文种给出3种网络结果对比,明显可以看出效果:FCN-32s < FCN-16s < FCN-8s,即使用多层feature融合有利于提高分割准确性。
二:U-Net
Unet 基于 Encoder-Decoder 结构,通过拼接的方式实现特征融合,结构简明且稳定,如果你有语义分割的问题,尤其在样本数据量不大的情况下,表现还是可以的。其图示如下:
如上图,Unet 网络结构是对称的,形似英文字母 U 所以被称为 Unet。整张图都是由蓝/白色框与各种颜色的箭头组成,其中,蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。
Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。pytorch 代码:
nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
另外,Encoder中的下采样采用的是maxpooling。pytorch 代码:
nn.MaxPool2d(kernel_size=2, stride=2)
Decoder中feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积比较关键的步骤就是 upsampling 与 skip-connection。
Upsampling 上采样常用的方式有两种:1.FCN 中介绍的反卷积;2. 插值。其中在插值方法中,bilinear 双线性插值的综合表现较好也较为常见。pytorch 代码:
nn.Upsample(scale_factor=2, mode='bilinear')
可用以下例子看看bilinear插值的效果。
import torch
from torch import nn
x = torch.rand(2, 3, 3, 2)
model = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # [2, 3, 6, 4]
model = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=True) # [2, 3, 9, 6]
y = model(x)
print(y.shape)
FNN 网络要想获得好效果,skip-connection 基本必不可少。Unet 的Decoder中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:
torch.cat([low_layer_features, deep_layer_features], dim=1)
这里需要注意的是,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 Unet 是通过拼接的方式。测试代码如下:
import torch
from torch import nn
low_layer_features = torch.rand(2, 3, 3, 2)
deep_layer_features = torch.rand(2, 3, 3, 2)
y = torch.cat([low_layer_features, deep_layer_features], dim=1) # [2, 6, 3, 2]
print(y.shape)
三:U-Net具体代码实现
好了,U-Net的结构也是分析完了,关键的步骤操作和试验也差不多了,现在我们来搭建下U-Net网络吧。完整代码如下:
from torch import nn import torch class UNet(nn.Module): def __init__(self, in_channels=1, num_classes=2): # num_classes,此处为 二分类值为2 super(UNet, self).__init__() # == Encoder == # 1. extract feayures, conv1 self.conv1 = nn.Sequential( nn.Conv2d(in_channels, 64, 3), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3), nn.BatchNorm2d(64), nn.ReLU(inplace=True) ) self.subpool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 2. extract feayures, conv2 self.conv2 = nn.Sequential( nn.Conv2d(64, 128, 3), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3), nn.BatchNorm2d(128), nn.ReLU(inplace=True) ) self.subpool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 3. extract feayures, conv3 self.conv3 = nn.Sequential( nn.Conv2d(128, 256, 3), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, 3), nn.BatchNorm2d(256), nn.ReLU(inplace=True) ) self.subpool3 = nn.MaxPool2d(kernel_size=2, stride=2) # 4. extract feayures, conv4 self.conv4 = nn.Sequential( nn.Conv2d(256, 512, 3), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, 3), nn.BatchNorm2d(512), nn.ReLU(inplace=True) ) self.subpool4 = nn.MaxPool2d(kernel_size=2, stride=2) # 5. extract feayures, conv5 self.conv5 = nn.Sequential( nn.Conv2d(512, 1024, 3), nn.BatchNorm2d(1024), nn.ReLU(inplace=True), nn.Conv2d(1024, 512, 3), nn.BatchNorm2d(512), nn.ReLU(inplace=True) ) # == Decoder == self.uppool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv6 = nn.Sequential( nn.Conv2d(1024, 512, 3), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 256, 3), nn.BatchNorm2d(256), nn.ReLU(inplace=True) ) self.uppool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv7 = nn.Sequential( nn.Conv2d(512, 256, 3), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 128, 3), nn.BatchNorm2d(128), nn.ReLU(inplace=True) ) self.uppool3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv8 = nn.Sequential( nn.Conv2d(256, 128, 3), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 64, 3), nn.BatchNorm2d(64), nn.ReLU(inplace=True) ) self.uppool4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv9 = nn.Sequential( nn.Conv2d(128, 64, 3), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 3), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, num_classes, 1), nn.BatchNorm2d(num_classes), nn.ReLU(inplace=True) ) def forward(self, x): # === encoder conv1 = self.conv1(x) conv1_sub = self.subpool1(conv1) conv2 = self.conv2(conv1_sub) conv2_sub = self.subpool2(conv2) conv3 = self.conv3(conv2_sub) conv3_sub = self.subpool3(conv3) conv4 = self.conv4(conv3_sub) conv4_sub = self.subpool4(conv4) conv5 = self.conv5(conv4_sub) # U型的最低端,它既是是encoder输出,也是decoder的输入。 # === deoder conv1_up = self.uppool1(conv5) conv6 = self.conv6(torch.cat([conv4, conv1_up], dim=1)) conv2_up = self.uppool2(conv6) conv7 = self.conv7(torch.cat([conv3, conv2_up], dim=1)) conv3_up = self.uppool3(conv7) conv8 = self.conv8(torch.cat([conv2, conv3_up], dim=1)) conv4_up = self.uppool4(conv8) conv9 = self.conv9(torch.cat([conv1, conv4_up], dim=1)) return conv9 if __name__ == '__main__': # model = VGGTest() x = torch.rand(64, 1, 572, 572) print(x.shape) model = UNet(in_channels=x.shape[1]) # print(model) y = model(x) print(y.shape)
四:和FCN的区别对比
U-Net采用了与FCN完全不同的特征融合方式
与FCN逐点相加不同,U-Net采用将特征在channel维度拼接在一起,形成更“厚”的特征。所以:
语义分割网络,在浅层和深层特征融合时也有2种办法:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。