当前位置:   article > 正文

【论文笔记】医学图像分割 U-Net++:A Nested U-Net Architecture_u-net++原文

u-net++原文

1 综述

今天分享一篇2018年的论文《UNet++: A Nested U-Net Architecture for Medical Image Segmentation》,已经有很多博客将其解读的很详细了,这里不再重复。这里主要提一下论文亮点:

简而言之,文章主要对 U-Net 中的 plain skip cennections 进行修改,作者认为Encoder 和 Decoder 的不同语义之间直接连接效果并不好,提出了嵌套的和稠密的跳跃连接来减小不同特征图之间的语义差距,达到改善分割效果的目的,并用了深监督进行训练,训练后可进行模型剪枝。

论文原文:The main idea behind UNet++ is to bridge the semantic gap between the feature maps of the encoder and decoder prior to fusion。

U-Net++作者周纵苇解读:U-net++解读

论文地址:
《UNet++: A Nested U-Net Architecture for Medical Image Segmentation》
《UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation》

代码地址:UNetPlusPlus

2 网络结构

下图为论文中给出的U-Net++网络结构,清晰明了:

黑色部分:是原始 U-Net 网络结构;

绿色和蓝色部分:是稠密连接( nested and dense skip connections ),改善原始 U-Net 结构中的 plain skip connections;

红色部分:因添加稠密连接后,在计算 loss function 时,梯度无法经过绿色和蓝色 bolck 区域,所以添加了红色部分进行深监督,实现训练;同时也便于后期进行模型剪枝;

不同深度网络:U-Net++因自身结构的原因,具有不同的网络深度,如L1、L2、L3 和 L4;

在这里插入图片描述

2.1 skip connection

不同与U-Net 网络中直接进行skip connection,U-Net++是嵌套的和稠密的跳跃连接来实现的,作者认为当解码和编码的特征图的语义相似进行融合时,它学习的效果会更好

在 UNet 中高分辨率特征图快速直接地从编码到解码,结果是不相似语义的特征图之间进行融合。与 UNet 的朴素跳跃连接不同,U-Net++ 将高分辨率特征图从 Encoder 网络逐渐地和 Decoder 网络中相应语义的特征图优先进行融合,这个网络它可以更高效地捕获前景对象的深层 ( fine-grained ) 细节。

2.2 deep supervision

监督每个分支的U-Net的输出,这样可以解决中间部分无法训练的问题,具体如下:

(1)在图中 X(0,1)、X(0,2)、X(0,3)、X(0,4) 后面加一个1x1的卷积核,将 feature map 的channel 数量变换到与 output_channel 数量一致,以 Dice + Cross Entropy 作为损失函数,来进行训练;

(2)实际训练中,对于不同的loss分支,作者的给出的权重为1:1:1:1;

(3)两种模式

精确模式:将输出的所有分割分支进行平均,得到最终分割结果;

快速模型:将得到的4个分割图,只选择其中一个分支,这个选择决定了模型修剪的程度和速度增益;

2.3 模型剪枝

作者是在测试阶段,在测试集上进行剪枝的;

作者解释在测试阶段,由于输入的图像只会前向传播,扔掉这部分对前面的输出完全没有的;而在训练阶段,因为既有前向,又有后向传播,被剪掉的部分是会帮助其他部分做权重更新的。因此测试时,剪掉部分对剩余结构不做影响,训练时,剪掉的分对剩余部分有影响;

论文中给出了对于不同的数据集,剪枝结果,不同程度的剪枝可减小不同程度的参数量,剪枝越多则参数量越少,但模型性能会退化,具体剪枝情况由不同数据集而异;此处图中,肺结节在L2就有较好的结果了;
在这里插入图片描述

3 分割结果对比

3.1 论文用到数据集

在这里插入图片描述

3.2 预测结果

论文中给出了在4种不同数据集上进行测试的结果:

(1)其中 U-Net 为 benchmark, wide U-Net 是加宽后的U-Net结构, 用于单纯增加网络参数量,便于控制变量进行对比;

(2)Table 3 是不同网络的对比,包括总参数大小、IOU系数;
在这里插入图片描述
在这里插入图片描述

4 源码解析

此处展示的是原始 U-Net++ 结构代码,对照网络结构看起来更易理解,存放在helper_functions.py中。代码是2D的,对于3D 图像,添加 depth 即可,同理;

作者也提供了在其他 backbone 进行U Net++化的代码,此处就不详细列举了;

########################################
# 2D Standard
########################################

def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):

    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)

    return x


Standard UNet++ [Zhou et.al, 2018]
Total params: 9,041,601
"""
def UNetPlusPlus(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):

    nb_filter = [32,64,128,256,512]

    # Handle Dimension Ordering for different backends
    global bn_axis
    if K.image_dim_ordering() == 'tf':
      bn_axis = 3
      img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
      bn_axis = 1
      img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
    	# 计算Loss时,4个分支权重为1:1:1:1
        model = Model(input=img_input, output=[nestnet_output_1,
                                               nestnet_output_2,
                                               nestnet_output_3,
                                               nestnet_output_4])
    else:
        model = Model(input=img_input, output=[nestnet_output_4])

    return model
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/760669
推荐阅读
相关标签
  

闽ICP备14008679号