
Sharing a UNet implementation (single-class and multi-class) and some thoughts on segmentation

Preface

I recently studied three segmentation algorithms: DeepLab-v3+, FCN, and UNet. FCN is the pioneering work among segmentation networks and is worth studying for that reason. DeepLab-v3+ is slower but more accurate, and its code is fairly complex to modify. For deployment, UNet is still the first choice: compared with object-detection networks its code is ridiculously simple, and I would also recommend it as a first network for anyone getting into deep learning.

Network architecture

As you can see, the network has a U shape: the encoder half downsamples with pooling, and the decoder half upsamples (the original paper uses transposed convolution; the code below uses UpSampling2D followed by a convolution). Encoder features are joined to the decoder at matching resolutions via concatenate.

from tensorflow.keras.layers import (Input, Lambda, Conv2D, AveragePooling2D,
                                     BatchNormalization, Dropout, UpSampling2D,
                                     concatenate)
from tensorflow.keras.models import Model

PIXEL = 32  # input resolution; the size comments below assume 32x32 inputs

inputs = Input((PIXEL, PIXEL, 3))
s = Lambda(lambda x: x / 255)(inputs)
conv1 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(s)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1)  # 16

conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2)  # 8

conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3)  # 4

conv4 = BatchNormalization(momentum=0.99)(pool3)
conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = BatchNormalization(momentum=0.99)(conv4)
conv4 = Conv2D(256, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = Dropout(0.02)(conv4)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)  # 2

conv5 = BatchNormalization(momentum=0.99)(pool4)
conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization(momentum=0.99)(conv5)
conv5 = Conv2D(512, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = Dropout(0.02)(conv5)
# Note: the original listing re-assigned pool4 twice after this point, which
# silently discarded the conv5 block. Those duplicate lines are removed here so
# conv5 feeds pool5, as the size comments (# 2, # 1) suggest was intended.
pool5 = AveragePooling2D(pool_size=(2, 2))(conv5)  # 1

conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
up7 = UpSampling2D(size=(2, 2))(conv7)  # 2
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
merge7 = concatenate([pool4, conv7], axis=3)

conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
up8 = UpSampling2D(size=(2, 2))(conv8)  # 4
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
merge8 = concatenate([pool3, conv8], axis=3)

conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
up9 = UpSampling2D(size=(2, 2))(conv9)  # 8
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
merge9 = concatenate([pool2, conv9], axis=3)

conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
up10 = UpSampling2D(size=(2, 2))(conv10)  # 16
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up10)

conv11 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
up11 = UpSampling2D(size=(2, 2))(conv11)  # 32
conv11 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up11)

conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
outputs = Conv2D(1, (1, 1), activation='sigmoid')(conv12)
model = Model(inputs=inputs, outputs=outputs)
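The single-class model above ends in a 1-channel sigmoid, so at inference time the probability map still has to be thresholded into a mask. A minimal post-processing sketch in NumPy (the 0.5 threshold and the 0/255 mask convention are common defaults, not something fixed by the listing above):

```python
import numpy as np

def prob_to_mask(prob, threshold=0.5):
    """Threshold an (H, W, 1) sigmoid probability map into a 0/255 uint8 mask."""
    return (np.squeeze(prob, axis=-1) > threshold).astype(np.uint8) * 255

prob = np.array([[[0.1], [0.9]],
                 [[0.6], [0.4]]])
mask = prob_to_mask(prob)
# mask is [[0, 255], [255, 0]]
```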


On multi-class segmentation

Most blog posts online only give this single-class network, so I am also sharing the multi-class version I got working. Making the change yourself is actually simple; there are a few steps and pitfalls I will walk through.

  1. The output of the final layer: the logistic (sigmoid) function only suits binary classification, so change it to relu.
  2. The loss function: likewise, change binary_crossentropy to mse.
  3. The most critical point, and in my view the reason many bloggers never get the multi-class version working: the dataset labels are wrong!
    If you also annotate with labelme, you will hit the following problems when converting the JSON files:
    3.1 Only one file can be converted at a time.
    3.2 With multiple labels, images and their labels get mismatched.
    3.3 The conversion to one-hot labels is incorrect.
    Please use the conversion script I provide (of the handling posted online so far, I have not seen a single correct one). A big pitfall!
    https://github.com/liuzehao/FCN-tools/blob/master/json_to_dataset.py
    Usage:
 python json_to_dataset.py ./path/to/your/json

Note: when using the script, change the class map to your own classes.

NAME_LABEL_MAP = {
    '_background_': 0,
    "cat": 1,
    "dog": 2,
}
 
LABEL_NAME_MAP = ['0: _background_',
                  '1: cat',
                  '2: dog']
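Step 3.3 above (the one-hot conversion) can be sketched in plain NumPy, assuming the converted label image is an (H, W) integer mask whose values follow NAME_LABEL_MAP (0 = background, 1 = cat, 2 = dog):

```python
import numpy as np

NUM_CLASSES = 3  # _background_, cat, dog — matches NAME_LABEL_MAP above

def mask_to_onehot(mask, num_classes=NUM_CLASSES):
    """Turn an (H, W) integer label mask into an (H, W, num_classes) one-hot array."""
    return np.eye(num_classes, dtype=np.float32)[mask]

mask = np.array([[0, 1],
                 [2, 0]])
onehot = mask_to_onehot(mask)
# onehot[0, 1] is [0., 1., 0.]  (class 1: cat)
```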

GitHub repository

https://github.com/liuzehao/unet-muti

Final thoughts

1. If you are also new to segmentation, skip the fruitless experiments I went through: for ordinary problems go straight to UNet, and for complex problems UNet is also the most extensible option.
2. The pooling downsampling can be replaced with convolution; pooling tends to weaken the gradient signal.
3. On segmentation itself: although segmentation code is simple, do not assume it can replace object detection. If the targets are highly similar to one another, use an object-detection algorithm; if they are highly similar and a bounding box is a poor fit, use a keypoint-based detection method. A lesson learned the hard way!
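Point 2 above works because 2x2 average pooling is just a stride-2 convolution with fixed uniform weights; making those weights learnable is the generalization. A minimal single-channel NumPy sketch of the equivalence (illustrative only, not the Keras layer):

```python
import numpy as np

def strided_conv2x2(x, kernel):
    """Valid 2x2 convolution with stride 2 over an (H, W) array."""
    h, w = x.shape
    out = np.empty((h // 2, w // 2))
    for i in range(h // 2):
        for j in range(w // 2):
            out[i, j] = np.sum(x[2*i:2*i+2, 2*j:2*j+2] * kernel)
    return out

x = np.arange(16, dtype=float).reshape(4, 4)
avg_kernel = np.full((2, 2), 0.25)       # uniform weights == average pooling
pooled = strided_conv2x2(x, avg_kernel)  # same result as 2x2 average pooling
```

A learned stride-2 Conv2D simply lets the network choose those four weights per channel instead of fixing them at 0.25.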
