Today I'm sharing a 2018 paper, "UNet++: A Nested U-Net Architecture for Medical Image Segmentation". Many blog posts have already covered it in detail, so I won't repeat that here. Instead, here are the paper's highlights:
In short, the paper modifies the plain skip connections in U-Net. The authors argue that directly connecting encoder and decoder feature maps with different semantics does not fuse well, so they propose nested and dense skip connections to reduce the semantic gap between feature maps and thereby improve segmentation. The network is trained with deep supervision, which also enables model pruning after training.
From the paper: "The main idea behind UNet++ is to bridge the semantic gap between the feature maps of the encoder and decoder prior to fusion."
The UNet++ author Zongwei Zhou's own explanation: U-Net++ explained
Papers:
"UNet++: A Nested U-Net Architecture for Medical Image Segmentation"
"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
Code: UNetPlusPlus
The figure below shows the U-Net++ architecture from the paper, which is clear at a glance:
Black: the original U-Net structure;
Green and blue: the nested and dense skip connections that replace the plain skip connections of the original U-Net;
Red: with the dense connections added, gradients from the loss function cannot reach the green and blue blocks, so the red parts introduce deep supervision to make them trainable; this also makes later model pruning possible;
Network depths: because of its structure, U-Net++ embeds sub-networks of different depths, labeled L1, L2, L3 and L4;
Unlike the direct skip connections in U-Net, U-Net++ uses nested and dense skip connections. The authors argue that fusion is easier to learn when the encoder and decoder feature maps being fused are semantically similar;
In U-Net, high-resolution feature maps travel fast and directly from the encoder to the decoder, so feature maps with dissimilar semantics get fused. In contrast, U-Net++ gradually enriches the high-resolution encoder feature maps before fusing them with the semantically corresponding decoder feature maps, which lets the network capture fine-grained details of foreground objects more effectively.
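The paper formalizes these skip pathways as follows. Let $x^{i,j}$ denote the output of node $X^{i,j}$, where $i$ indexes the down-sampling level along the encoder and $j$ indexes the convolution block along the skip pathway:

$$
x^{i,j} =
\begin{cases}
\mathcal{H}\left(x^{i-1,\,j}\right), & j = 0 \\
\mathcal{H}\left(\left[\,[x^{i,k}]_{k=0}^{j-1},\ \mathcal{U}\left(x^{i+1,\,j-1}\right)\right]\right), & j > 0
\end{cases}
$$

Here $\mathcal{H}(\cdot)$ is a convolution block followed by an activation, $\mathcal{U}(\cdot)$ is up-sampling, and $[\,\cdot\,]$ denotes concatenation. Nodes at level $j = 0$ only see the encoder, while every deeper node fuses all preceding feature maps at its own resolution with the up-sampled output from the level below.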
The output of each embedded U-Net branch is supervised, which solves the problem that the intermediate parts could otherwise not be trained. Concretely:
(1) A 1x1 convolution is appended after X(0,1), X(0,2), X(0,3) and X(0,4) in the figure to map the feature maps' channel count to the number of output channels, and training uses Dice + cross-entropy as the loss function (see the sketch after this list);
(2) In practice, the authors weight the four loss branches 1:1:1:1;
(3) Two inference modes:
Accurate mode: average the segmentation maps from all output branches to obtain the final result;
Fast mode: of the four segmentation maps, keep only one branch; the choice of branch determines the degree of pruning and the speed gain;
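For reference, here is a minimal sketch of a combined Dice + binary cross-entropy loss in Keras. The exact formulation in the authors' repository may differ; the smoothing constant and the 1:1 combination of the two terms are assumptions here.

from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    # Soft Dice coefficient over the flattened masks
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    # Binary cross-entropy plus (1 - Dice), combined 1:1 (assumed weighting)
    return K.mean(K.binary_crossentropy(y_true, y_pred)) + (1. - dice_coef(y_true, y_pred))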
The authors perform pruning at test time, on the test set;
Their explanation: at test time the input only propagates forward, so removing the pruned parts has no effect at all on the remaining outputs. During training, however, there are both forward and backward passes, and the parts that are later pruned help update the weights of the rest of the network. Hence the pruned parts do not affect the remaining structure at test time, but do contribute to it during training;
The paper reports pruning results on different datasets: heavier pruning removes more parameters, but model performance degrades, and how much can be pruned varies by dataset. In the figure shown here, the lung nodule dataset already achieves good results at L2;
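The pruning itself is easy to express in Keras. A hypothetical sketch, assuming a model built with deep_supervision=True (see the full code further below): re-wrapping the graph up to a chosen output head yields a sub-network that only executes the layers that head depends on.

from keras.models import Model

# Hypothetical sketch: prune a trained deep-supervision model down to L2
# by keeping only the layers the 'output_2' head depends on.
# 'full_model' is assumed to be a trained UNetPlusPlus(..., deep_supervision=True).
pruned_l2 = Model(inputs=full_model.input,
                  outputs=full_model.get_layer('output_2').output)
# pruned_l2 reuses the trained weights; at test time only the forward path
# up to X(0,2) is evaluated, which is exactly the speed gain of fast mode.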
The paper reports test results on four different datasets:
(1) U-Net serves as the benchmark; wide U-Net is a widened U-Net whose only purpose is to raise the parameter count, so the comparison controls for model size;
(2) Table 3 compares the networks, including total parameter count and IoU;
Below is the code for the original U-Net++ architecture, which is easier to follow side by side with the network diagram; it lives in helper_functions.py. The code is 2D; for 3D images, add a depth dimension in the same fashion;
The authors also provide code for building U-Net++ on top of other backbones, which I won't list here;
from keras import backend as K
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, concatenate
from keras.models import Model
from keras.regularizers import l2

# Module-level settings from helper_functions.py; the exact values here are
# assumptions made so the snippet is self-contained.
act = 'elu'
dropout_rate = 0.5

########################################
# 2D Standard
########################################
def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):
    # Two Conv-Dropout pairs form the basic convolution block H(.)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1',
               kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2',
               kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)
    return x

"""
Standard UNet++ [Zhou et al., 2018]
Total params: 9,041,601
"""
def UNetPlusPlus(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):
    nb_filter = [32, 64, 128, 256, 512]

    # Handle dimension ordering for different backends
    # (the original code used the older K.image_dim_ordering() == 'tf' check)
    global bn_axis
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
        img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
        bn_axis = 1
        img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    # Encoder backbone: X(i,0) nodes with max-pooling between levels
    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    # Nested skip pathways: each X(i,j) concatenates all X(i,k), k<j,
    # with the up-sampled X(i+1,j-1)
    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    # Deep-supervision heads: 1x1 convolutions on X(0,1)..X(0,4)
    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1',
                              kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2',
                              kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3',
                              kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4',
                              kernel_initializer='he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
        # The four loss branches are weighted 1:1:1:1 when computing the loss
        model = Model(inputs=img_input, outputs=[nestnet_output_1, nestnet_output_2,
                                                 nestnet_output_3, nestnet_output_4])
    else:
        model = Model(inputs=img_input, outputs=[nestnet_output_4])

    return model
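As a usage sketch (the 96x96 input size, optimizer, and training call are illustrative only), the model can be built with the deep-supervision heads and compiled with the bce_dice_loss sketched earlier; with deep_supervision=True the same mask is supplied as the target for each of the four outputs:

model = UNetPlusPlus(96, 96, color_type=1, num_class=1, deep_supervision=True)
model.compile(optimizer='Adam',
              loss=bce_dice_loss,        # applied to each of the four outputs
              loss_weights=[1, 1, 1, 1]) # the paper's 1:1:1:1 branch weighting
# model.fit(x_train, [y_train, y_train, y_train, y_train], ...)

preds = model.predict(x_test)       # list of four segmentation maps
avg_pred = sum(preds) / len(preds)  # accurate mode: average all branches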