当前位置:   article > 正文

语义分割(三)Unet++_unet++网络结构

unet++网络结构

Unet++

Unet++论文
UNet++是2018年提出的网络,是U-Net的一个加强版本。

Unet++特点

其相对U-Net改进之处主要为:

  1. 网络结合了类DenseNet结构,密集的跳跃连接提高了梯度流动性。
  2. 将U-Net的空心结构填满,连接了编码器和解码器特征图之间的语义鸿沟。
  3. 使用了深度监督,可以进行剪枝。

Unet++网络结构

它的网络结构是这样的:
在这里插入图片描述
从结构图中可以看到,把空心的UNet填满了。水平方向上每一点都有连接,类似DenseNet,这样可以抓取不同层次的特征。由于不同深度的感受野对大小不同的目标敏感程度不一样,浅层的对小目标更敏感,深层对大目标更敏感,通过特征concat拼接到一起,可以整合二者的优点。此外随着网络层数的增加,一层层的下采样会不断丢失信息(大物体边缘信息和小物体本身),所以用感受野小的特征进行补充,让位置信息更准确。
在计算损失函数时,增加了几个梯度,这是由于只用最后的输出计算损失事,中间层就无法反向传播,因此无法训练中间层参数。于是在 X 0 , 1 X^{0,1} X0,1 X 0 , 2 X^{0,2} X0,2 X 0 , 3 X^{0,3} X0,3 X 0 , 4 X^{0,4} X0,4,这四个地方分别计算损失,最后将四个相加来优化(当然也可以加权重),这样的话既考虑深层信息也考虑浅层信息。
此外,在做计算损失之前再加一个1x1卷积,用一个sigmoid作为激活函数,让输出和标签在同一个范围内,并且还有再次提取特征,对通道进行改变等作用。
Unet++的损失函数如下:
在这里插入图片描述
这里使用了一个BCE和Dice Coefficient结合的损失,应用到每一个不同层次的输出。

模型剪枝

这种结构可以看做是以下四个子网络的结合:
在这里插入图片描述
在深监督的过程中,每个子网络的输出已经都是图像的分割结果。这种情况下,如果前面的子网络分割效果足够好,就不需要再继续计算后面子网络的结果了。
剪枝只在测试的时候进行,训练的时候不剪枝。这是因为测试时只会前向而没有后向,而在训练过程中反向传播时几个损失是相互促进的,这时剪枝会有影响训练。
测试时的剪枝,相对简单,例如在用图片测试,网络的前面两层效果好,那就可以使用前两层构建的子网络,后面就可以剪掉不要了。
因此Unet++模型会有两种操作模式:
Accurate mode(精确模式):对所有分割分支的输出计算平均值。
Fast mode(快速模式):从其中的所有的分支输出中选择一个作为输出,该选择决定了Unet++模型的剪枝程度和速度增益。

Unet++模型实现

下面给出了Unet++模型的keras实现:

###########################
#Unet++ loss              #
###########################
def dice_coef(y_true, y_pred):
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def bce_dice_loss(y_true, y_pred):
    return 0.5 * binary_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)
    #return 0.5 * categorical_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)
#############################
#Unet++ conv and upsampling #
#############################
def conv_drop(inputs, filters):
    conv1 = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu',
                   kernel_initializer='he_normal')(inputs)
    #drop1 = Dropout(rate=0.5)(conv1)
    conv2 = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu',
                   kernel_initializer='he_normal')(conv1)
    #drop2 = Dropout(rate=0.5)(conv2)
    return conv2
    
def upsampling(inputs, filters):
    up = UpSampling2D(size=(2, 2))(inputs)
    conv = Conv2D(filters=filters, kernel_size=2, activation='relu', padding='same',
                  kernel_initializer='he_normal')(up)
    return conv
###############################
# Unet++                      #
###############################
def Unet_plusplus(input_size=(224, 224, 1), n_class=2, filters=(32, 64, 128, 256, 512), re_shape=False):
    inputs = Input(shape=input_size)
    ## l1
    conv0_0 = conv_drop(inputs=inputs, filters=filters[0])
    pool00_10 = MaxPool2D(pool_size=(2, 2))(conv0_0)
    conv1_0 = conv_drop(inputs=pool00_10, filters=filters[1])
    up10_01 = upsampling(inputs=conv1_0, filters=filters[0])
    concat1_1 = concatenate([up10_01, conv0_0], axis=3)
    conv0_1 = conv_drop(inputs=concat1_1, filters=filters[0])
    ## l2
    pool10_20 = MaxPool2D(pool_size=(2, 2))(conv1_0)
    conv2_0 = conv_drop(inputs=pool10_20, filters=filters[2])
    up20_11 = upsampling(inputs=conv2_0, filters=filters[1])
    concat2_1 = concatenate([up20_11, conv1_0], axis=3)
    conv1_1 = conv_drop(inputs=concat2_1, filters=filters[1])
    up11_02 = upsampling(inputs=conv1_1, filters=filters[0])
    concat2_2 = concatenate([up11_02, conv0_0, conv0_1], axis=3)
    conv0_2 = conv_drop(inputs=concat2_2, filters=filters[0])
    ##l3
    pool20_30 = MaxPool2D(pool_size=(2, 2))(conv2_0)
    conv3_0 = conv_drop(inputs=pool20_30, filters=filters[3])
    up30_21 = upsampling(inputs=conv3_0, filters=filters[2])
    concat3_1 = concatenate([up30_21, conv2_0], axis=3)
    conv2_1 = conv_drop(inputs=concat3_1, filters=filters[2])
    up21_12 = upsampling(inputs=conv2_1, filters=filters[1])
    concat3_2 = concatenate([up21_12, conv1_0, conv1_1], axis=3)
    conv1_2 = conv_drop(inputs=concat3_2, filters=filters[1])
    up12_03 = upsampling(inputs=conv1_2, filters=filters[0])
    concat3_3 = concatenate([up12_03, conv0_0, conv0_1, conv0_2], axis=3)
    conv0_3 = conv_drop(inputs=concat3_3, filters=filters[0])
    ## l4
    pool30_40 = MaxPool2D(pool_size=(2, 2))(conv3_0)
    conv4_0 = conv_drop(inputs=pool30_40, filters=filters[4])
    up40_31 = upsampling(inputs=conv4_0, filters=filters[3])
    concat4_1 = concatenate([up40_31, conv3_0], axis=3)
    conv3_1 = conv_drop(inputs=concat4_1, filters=filters[3])
    up31_22 = upsampling(inputs=conv3_1, filters=filters[2])
    concat4_2 = concatenate([up31_22, conv2_0, conv2_1], axis=3)
    conv2_2 = conv_drop(inputs=concat4_2, filters=filters[2])
    up22_13 = upsampling(inputs=conv2_2, filters=filters[1])
    concat4_3 = concatenate([up22_13, conv1_0, conv1_1, conv1_2], axis=3)
    conv1_3 = conv_drop(inputs=concat4_3, filters=filters[1])
    up13_04 = upsampling(inputs=conv1_3, filters=filters[0])
    concat4_4 = concatenate([up13_04, conv0_0, conv0_1, conv0_2, conv0_3], axis=3)
    conv0_4 = conv_drop(inputs=concat4_4, filters=filters[0])
    ## output
    l1_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same', kernel_initializer='he_normal')(conv0_1)
    l2_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same', kernel_initializer='he_normal')(conv0_2)
    l3_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same', kernel_initializer='he_normal')(conv0_3)
    l4_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same', kernel_initializer='he_normal')(conv0_4)

    if re_shape==True:
        l1_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l1_conv_out)
        l2_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l2_conv_out)
        l3_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l3_conv_out)
        l4_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l4_conv_out)

    l1_out = Activation('sigmoid', name='l1_out')(l1_conv_out)
    l2_out = Activation('sigmoid', name='l2_out')(l2_conv_out)
    l3_out = Activation('sigmoid', name='l3_out')(l3_conv_out)
    l4_out = Activation('sigmoid', name='l4_out')(l4_conv_out)

    model = Model(input=inputs, output=[l1_out, l2_out, l3_out, l4_out])
    #model = Model(input=inputs, output=l4_out)
    model.summary()
    losses= {
        'l1_out': bce_dice_loss,
        'l2_out': bce_dice_loss,
        'l3_out': bce_dice_loss,
        'l4_out': bce_dice_loss,
    }
    model.compile(optimizer=Adam(lr=1e-4), loss=losses, metrics=['accuracy'])
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105

代码中定义了Unet++的损失函数以及网络结构。希望对大家理解Unet++有所帮助。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/235745
推荐阅读
  

闽ICP备14008679号