当前位置:   article > 正文

Keras框架下的U-Net实现(可实现多类分割+代码)_2d u-net可以进行多类分割吗

2d u-net可以进行多类分割吗

U-Net网络结构

在这里插入图片描述
U-Net是一种全卷积网络,当卷积操作采用"padding = same"的形式,卷积不会使图片尺寸变小;maxpooling操作以2为步长,得到的结果是原图的 1 4 \frac14 41(宽和高均为原图的 1 2 \frac12 21);后面有对应的上采样层恢复图像的尺寸。所以,最终可以得到和原图大小一样的分割结果。原文地址

U-Net的Keras实现

具体代码详见Github

Readme

可以在configuration.txt中修改适合自己数据的参数,如路径(数据路径,模型保存路径)、图像参数(h,w,channels,classes)、网络参数(N_epoch,batchsize)等。

部分代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author:CaoZhihui
...
# --Build a net work-- #
def get_net():
    inputs = Input(shape=(N_channels, img_h, img_w))
    # Block 1
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(conv1)

    # Block 2
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(conv2)

    # Block 3
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(conv3)

    # Block 4
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(conv4)

    # Block 5
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv5)

    up6 = merge([UpSampling2D(size=(2, 2))(conv5), conv4], mode='concat', concat_axis=1)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    up7 = merge([UpSampling2D(size=(2, 2))(conv6), conv3], mode='concat', concat_axis=1)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = merge([UpSampling2D(size=(2, 2))(conv7), conv2], mode='concat', concat_axis=1)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = merge([UpSampling2D(size=(2, 2))(conv8), conv1], mode='concat', concat_axis=1)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal')(up9)

    conv10 = Conv2D(C, (1, 1), activation='relu', kernel_initializer='he_normal')(conv9)

    reshape = Reshape((C, img_h * img_w), input_shape=(C, img_h, img_w))(conv10)
    reshape = Permute((2, 1))(reshape)

    activation = Activation('softmax')(reshape)

    model = Model(input=inputs, output=activation)

    model.compile(optimizer=Adam(lr=1.0e-4), loss='categorical_crossentropy', metrics=['accuracy'])

    return model
...

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60

后记

暂且只贴出github上的代码,待日后时间充足再详细介绍并做补充。
未完,待续…

  这个数据集我已经没有了,所以很多问题我没法自己实验来解答了;因为这个是当初入门时写的代码,好多地方不是很严谨,可能移植性并不是很好。
  后面会慢慢抽空更新pytorch框架下resnet,resnext等backbone的U-Net、U-Net++(Nested U-Net)、E-Net等,敬请期待…
  绝对不鸽

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号