当前位置:   article > 正文

图像语义分割 -- U-Net_u-net语义分割代码

u-net语义分割代码

一:FCN回顾

上一博文我们学习了FCN,有不同的特征融合版本。
至于为什么要进行特征能融合呢?由于池化操作的存在,浅层卷积视野小,具体一些,细节更加详细,越深层的视野大,图像越小,越粗粒度,细节也是越来越模糊,所以,下采样的好处是,带来了感受域的提升,同时也减少计算量,但是却忽略了很多细节,让图像变得平湖模糊,因此,作者将浅层的细节特征也进行了特征融合。

较浅的卷积层(靠前的)的感受域比较小,学习感知细节部分的能力强,较深的隐藏层 (靠后的),感受域相对较大,适合学习较为整体的、相对更宏观一些的特征。
所以在较深的卷积层上进行反卷积还原,自然会丢失很多细节特征。
于是我们会在反卷积步骤时,考虑采用一部分较浅层的反卷积信息辅助叠加,更好的优化分割结果的精度:

至于效果具体是如何呢?
作者在原文种给出3种网络结果对比,明显可以看出效果:FCN-32s < FCN-16s < FCN-8s,即使用多层feature融合有利于提高分割准确性。
在这里插入图片描述

二:U-Net
Unet 基于 Encoder-Decoder 结构,通过拼接的方式实现特征融合,结构简明且稳定,如果你有语义分割的问题,尤其在样本数据量不大的情况下,表现还是可以的。其图示如下:

在这里插入图片描述

如上图,Unet 网络结构是对称的,形似英文字母 U 所以被称为 Unet。整张图都是由蓝/白色框与各种颜色的箭头组成,其中,蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。

Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。pytorch 代码:

nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True))
  • 1
  • 2
  • 3

另外,Encoder中的下采样采用的是maxpooling。pytorch 代码:

nn.MaxPool2d(kernel_size=2, stride=2)
  • 1

Decoder中feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积比较关键的步骤就是 upsampling 与 skip-connection。

Upsampling 上采样常用的方式有两种:1.FCN 中介绍的反卷积;2. 插值。其中在插值方法中,bilinear 双线性插值的综合表现较好也较为常见。pytorch 代码:

nn.Upsample(scale_factor=2, mode='bilinear')
  • 1

可用以下例子看看bilinear插值的效果。

import torch
from torch import nn

x = torch.rand(2, 3, 3, 2)
model = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # [2, 3, 6, 4]
model = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=True)  # [2, 3, 9, 6]
y = model(x)
print(y.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

FNN 网络要想获得好效果,skip-connection 基本必不可少。Unet 的Decoder中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:

torch.cat([low_layer_features, deep_layer_features], dim=1)
  • 1

这里需要注意的是,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 Unet 是通过拼接的方式。测试代码如下:

import torch
from torch import nn

low_layer_features = torch.rand(2, 3, 3, 2)
deep_layer_features = torch.rand(2, 3, 3, 2)
y = torch.cat([low_layer_features, deep_layer_features], dim=1)  # [2, 6, 3, 2]
print(y.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

三:U-Net具体代码实现
好了,U-Net的结构也是分析完了,关键的步骤操作和试验也差不多了,现在我们来搭建下U-Net网络吧。完整代码如下:

from torch import nn
import torch


class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):  # num_classes,此处为 二分类值为2
        super(UNet, self).__init__()
        # == Encoder ==
        # 1. extract feayures, conv1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.subpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 2. extract feayures, conv2
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.subpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 3. extract feayures, conv3
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.subpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 4. extract feayures, conv4
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.subpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 5. extract feayures, conv5
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),

            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # == Decoder ==
        self.uppool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv6 = nn.Sequential(
            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.uppool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = nn.Sequential(
            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.uppool3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv8 = nn.Sequential(
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.uppool4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv9 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, num_classes, 1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # === encoder
        conv1 = self.conv1(x)
        conv1_sub = self.subpool1(conv1)

        conv2 = self.conv2(conv1_sub)
        conv2_sub = self.subpool2(conv2)

        conv3 = self.conv3(conv2_sub)
        conv3_sub = self.subpool3(conv3)

        conv4 = self.conv4(conv3_sub)
        conv4_sub = self.subpool4(conv4)

        conv5 = self.conv5(conv4_sub)  # U型的最低端,它既是是encoder输出,也是decoder的输入。

        # === deoder
        conv1_up = self.uppool1(conv5)
        conv6 = self.conv6(torch.cat([conv4, conv1_up], dim=1))

        conv2_up = self.uppool2(conv6)
        conv7 = self.conv7(torch.cat([conv3, conv2_up], dim=1))

        conv3_up = self.uppool3(conv7)
        conv8 = self.conv8(torch.cat([conv2, conv3_up], dim=1))

        conv4_up = self.uppool4(conv8)
        conv9 = self.conv9(torch.cat([conv1, conv4_up], dim=1))

        return conv9


if __name__ == '__main__':
    # model = VGGTest()
    x = torch.rand(64, 1, 572, 572)
    print(x.shape)

    model = UNet(in_channels=x.shape[1])
    # print(model)
    y = model(x)
    print(y.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157

四:和FCN的区别对比
U-Net采用了与FCN完全不同的特征融合方式
与FCN逐点相加不同,U-Net采用将特征在channel维度拼接在一起,形成更“厚”的特征。所以:
语义分割网络,在浅层和深层特征融合时也有2种办法:

  1. FCN式的浅层特征和深层特征逐点相加。
  2. U-Net式的channel维度拼接融合。
    相比其他大型网络,FCN/U-Net还是蛮简单的,就不多废话了。
    总结一下,CNN图像语义分割也就基本上是这个套路:
  3. 下采样+上采样:Convlution + Deconvlution/Resize
  4. 多层次特征融合:特征逐点相加/特征channel维度拼接
  5. 获得像素级别的segement map:对每一个像素点进行判断类别
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/87784
推荐阅读
相关标签
  

闽ICP备14008679号