The Unet++ Paper
UNet++ is a network proposed in 2018 as an enhanced version of U-Net.
Its main improvements over U-Net show up in its network structure, which looks like this:
From the structure diagram you can see that the hollow interior of U-Net has been filled in. Every node along the horizontal direction is connected, similar to DenseNet, which lets the network capture features at different levels. Receptive fields at different depths are sensitive to targets of different sizes: shallow layers are more sensitive to small objects, deep layers to large ones, so concatenating the features together integrates the strengths of both. In addition, as the network deepens, repeated downsampling keeps discarding information (the edges of large objects and small objects themselves), so features with small receptive fields are used as a supplement to make positional information more accurate.
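In the notation of the original paper, the feature map at node $X^{i,j}$ (where $i$ indexes the depth along the encoder and $j$ the convolution block along the skip pathway) is

$$x^{i,j} = \begin{cases} \mathcal{H}\left(x^{i-1,j}\right), & j = 0 \\ \mathcal{H}\left(\left[\left[x^{i,k}\right]_{k=0}^{j-1},\ \mathcal{U}\left(x^{i+1,j-1}\right)\right]\right), & j > 0 \end{cases}$$

where $\mathcal{H}(\cdot)$ is a convolution block, $\mathcal{U}(\cdot)$ denotes upsampling, and $[\,\cdot\,]$ denotes concatenation: each node receives all of its same-level predecessors plus the upsampled output of the node one level deeper.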
When computing the loss function, several extra supervision points (and thus gradient paths) are added. If the loss were computed only from the final output, gradients could not backpropagate into the intermediate branches, so their parameters could not be trained. A loss is therefore computed at each of the four nodes $X^{0,1}$, $X^{0,2}$, $X^{0,3}$ and $X^{0,4}$, and the four losses are summed for optimization (weights can also be applied, as sketched below), so that both deep and shallow information is taken into account.
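A minimal sketch of the weighted combination, assuming the four-output Keras model and the bce_dice_loss defined later in this post; Keras sums the per-output losses by default, and loss_weights sets the optional weights:

model.compile(optimizer=Adam(lr=1e-4),
              loss={'l1_out': bce_dice_loss, 'l2_out': bce_dice_loss,
                    'l3_out': bce_dice_loss, 'l4_out': bce_dice_loss},
              # equal weights by default; change these to weight the branches
              loss_weights={'l1_out': 1.0, 'l2_out': 1.0,
                            'l3_out': 1.0, 'l4_out': 1.0})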
In addition, before the loss is computed, one more 1x1 convolution with a sigmoid activation is applied, so that the output and the labels lie in the same value range; it also extracts features once more and adjusts the number of channels.
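In the implementation below, each output head looks like the following sketch (shown for the deepest branch conv0_4):

# A 1x1 convolution maps the channels to n_class, then a sigmoid squashes
# the outputs into [0, 1], the same range as the labels.
out = Conv2D(filters=n_class, kernel_size=1, padding='same',
             kernel_initializer='he_normal')(conv0_4)
out = Activation('sigmoid')(out)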
The loss function of Unet++ is as follows:
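For reference, the deep-supervision loss given in the original paper, with $\hat{Y}_b$ and $Y_b$ the flattened prediction and ground truth of the $b$-th image and $N$ the batch size, is

$$\mathcal{L}(Y, \hat{Y}) = -\frac{1}{N}\sum_{b=1}^{N}\left(\frac{1}{2}\cdot Y_b\cdot\log\hat{Y}_b + \frac{2\cdot Y_b\cdot\hat{Y}_b}{Y_b + \hat{Y}_b}\right)$$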
This combines BCE (binary cross-entropy) with the Dice coefficient, and the combined loss is applied to each of the different output levels.
This structure can also be viewed as the combination of the following four sub-networks:
Under deep supervision, the output of each sub-network is already a segmentation of the image. Consequently, if an earlier sub-network already segments well enough, the later sub-networks need not be computed at all.
Pruning is performed only at test time, never during training: inference needs only the forward pass, whereas during training the several losses reinforce one another through backpropagation, so pruning there would harm training.
Test-time pruning is relatively simple. For example, if the first two levels of the network already perform well on the test images, you can keep only the sub-network built from those two levels and prune away the rest, as sketched below.
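A minimal sketch of such pruning, assuming the Keras model defined later in this post (its four supervised outputs are named 'l1_out' to 'l4_out') and a hypothetical test batch x_test:

from keras.models import Model

# Build a sub-model that ends at the second supervised output; predict()
# then executes only the layers that output depends on, skipping l3/l4.
pruned = Model(inputs=model.input, outputs=model.get_layer('l2_out').output)
mask = pruned.predict(x_test)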
The Unet++ model can therefore be operated in two modes:
Accurate mode: the outputs of all segmentation branches are averaged.
Fast mode: one of the branch outputs is selected as the final result; the choice determines the degree of pruning and the speed gain (see the sketch after this list).
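A minimal sketch of both modes, assuming the four-output Keras model defined below and a hypothetical test batch x_test:

import numpy as np

p1, p2, p3, p4 = model.predict(x_test)  # one prediction per supervised branch

# Accurate mode: average all four branch outputs.
accurate = np.mean(np.stack([p1, p2, p3, p4]), axis=0)

# Fast mode: keep a single branch; an earlier branch permits more pruning.
fast = p2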
A Keras implementation of the Unet++ model is given below:
from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPool2D, UpSampling2D,
                          concatenate, Reshape, Activation, Dropout)
from keras.losses import binary_crossentropy
from keras.optimizers import Adam
from keras import backend as K

###########################
#       Unet++ loss       #
###########################
def dice_coef(y_true, y_pred):
    # Soft Dice coefficient on the flattened masks; smooth avoids division by zero.
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def bce_dice_loss(y_true, y_pred):
    return 0.5 * binary_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)
    #return 0.5 * categorical_crossentropy(y_true, y_pred) - dice_coef(y_true, y_pred)

#############################
# Unet++ conv and upsampling #
#############################
def conv_drop(inputs, filters):
    # Two 3x3 convolutions; dropout is left commented out as in the original.
    conv1 = Conv2D(filters=filters, kernel_size=3, padding='same',
                   activation='relu', kernel_initializer='he_normal')(inputs)
    #drop1 = Dropout(rate=0.5)(conv1)
    conv2 = Conv2D(filters=filters, kernel_size=3, padding='same',
                   activation='relu', kernel_initializer='he_normal')(conv1)
    #drop2 = Dropout(rate=0.5)(conv2)
    return conv2

def upsampling(inputs, filters):
    # 2x upsampling followed by a 2x2 convolution.
    up = UpSampling2D(size=(2, 2))(inputs)
    conv = Conv2D(filters=filters, kernel_size=2, activation='relu',
                  padding='same', kernel_initializer='he_normal')(up)
    return conv

###############################
#           Unet++            #
###############################
def Unet_plusplus(input_size=(224, 224, 1), n_class=2,
                  filters=(32, 64, 128, 256, 512), re_shape=False):
    inputs = Input(shape=input_size)
    ## l1
    conv0_0 = conv_drop(inputs=inputs, filters=filters[0])
    pool00_10 = MaxPool2D(pool_size=(2, 2))(conv0_0)
    conv1_0 = conv_drop(inputs=pool00_10, filters=filters[1])
    up10_01 = upsampling(inputs=conv1_0, filters=filters[0])
    concat1_1 = concatenate([up10_01, conv0_0], axis=3)
    conv0_1 = conv_drop(inputs=concat1_1, filters=filters[0])
    ## l2
    pool10_20 = MaxPool2D(pool_size=(2, 2))(conv1_0)
    conv2_0 = conv_drop(inputs=pool10_20, filters=filters[2])
    up20_11 = upsampling(inputs=conv2_0, filters=filters[1])
    concat2_1 = concatenate([up20_11, conv1_0], axis=3)
    conv1_1 = conv_drop(inputs=concat2_1, filters=filters[1])
    up11_02 = upsampling(inputs=conv1_1, filters=filters[0])
    concat2_2 = concatenate([up11_02, conv0_0, conv0_1], axis=3)
    conv0_2 = conv_drop(inputs=concat2_2, filters=filters[0])
    ## l3
    pool20_30 = MaxPool2D(pool_size=(2, 2))(conv2_0)
    conv3_0 = conv_drop(inputs=pool20_30, filters=filters[3])
    up30_21 = upsampling(inputs=conv3_0, filters=filters[2])
    concat3_1 = concatenate([up30_21, conv2_0], axis=3)
    conv2_1 = conv_drop(inputs=concat3_1, filters=filters[2])
    up21_12 = upsampling(inputs=conv2_1, filters=filters[1])
    concat3_2 = concatenate([up21_12, conv1_0, conv1_1], axis=3)
    conv1_2 = conv_drop(inputs=concat3_2, filters=filters[1])
    up12_03 = upsampling(inputs=conv1_2, filters=filters[0])
    concat3_3 = concatenate([up12_03, conv0_0, conv0_1, conv0_2], axis=3)
    conv0_3 = conv_drop(inputs=concat3_3, filters=filters[0])
    ## l4
    pool30_40 = MaxPool2D(pool_size=(2, 2))(conv3_0)
    conv4_0 = conv_drop(inputs=pool30_40, filters=filters[4])
    up40_31 = upsampling(inputs=conv4_0, filters=filters[3])
    concat4_1 = concatenate([up40_31, conv3_0], axis=3)
    conv3_1 = conv_drop(inputs=concat4_1, filters=filters[3])
    up31_22 = upsampling(inputs=conv3_1, filters=filters[2])
    concat4_2 = concatenate([up31_22, conv2_0, conv2_1], axis=3)
    conv2_2 = conv_drop(inputs=concat4_2, filters=filters[2])
    up22_13 = upsampling(inputs=conv2_2, filters=filters[1])
    concat4_3 = concatenate([up22_13, conv1_0, conv1_1, conv1_2], axis=3)
    conv1_3 = conv_drop(inputs=concat4_3, filters=filters[1])
    up13_04 = upsampling(inputs=conv1_3, filters=filters[0])
    concat4_4 = concatenate([up13_04, conv0_0, conv0_1, conv0_2, conv0_3], axis=3)
    conv0_4 = conv_drop(inputs=concat4_4, filters=filters[0])
    ## output: one 1x1 convolution head per supervised branch
    l1_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same',
                         kernel_initializer='he_normal')(conv0_1)
    l2_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same',
                         kernel_initializer='he_normal')(conv0_2)
    l3_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same',
                         kernel_initializer='he_normal')(conv0_3)
    l4_conv_out = Conv2D(filters=n_class, kernel_size=1, padding='same',
                         kernel_initializer='he_normal')(conv0_4)
    if re_shape:
        l1_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l1_conv_out)
        l2_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l2_conv_out)
        l3_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l3_conv_out)
        l4_conv_out = Reshape((input_size[0] * input_size[1], n_class))(l4_conv_out)
    l1_out = Activation('sigmoid', name='l1_out')(l1_conv_out)
    l2_out = Activation('sigmoid', name='l2_out')(l2_conv_out)
    l3_out = Activation('sigmoid', name='l3_out')(l3_conv_out)
    l4_out = Activation('sigmoid', name='l4_out')(l4_conv_out)
    model = Model(inputs=inputs, outputs=[l1_out, l2_out, l3_out, l4_out])
    #model = Model(inputs=inputs, outputs=l4_out)
    model.summary()
    losses = {
        'l1_out': bce_dice_loss,
        'l2_out': bce_dice_loss,
        'l3_out': bce_dice_loss,
        'l4_out': bce_dice_loss,
    }
    model.compile(optimizer=Adam(lr=1e-4), loss=losses, metrics=['accuracy'])
    return model
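A usage sketch (x_train and y_train are hypothetical arrays of images and binary masks; every supervised output is trained against the same mask):

model = Unet_plusplus(input_size=(224, 224, 1), n_class=1)
model.fit(x_train,
          {'l1_out': y_train, 'l2_out': y_train,
           'l3_out': y_train, 'l4_out': y_train},
          batch_size=8, epochs=10)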
The code above defines the Unet++ loss function and network structure. I hope it helps you understand Unet++.