当前位置:   article > 正文

Keras实现SegNet

Keras实现SegNet

我真是服了,原来我之前用 TensorFlow 复现的 SegNet 是错的。
在网上试了多个版本的代码,折腾了好久,现在终于复现对了,代码也能跑通了。
SegNet 的架构比较老了,这几年几乎没人更新相关代码,这里算是提供一个在较新版本 TensorFlow 下能跑通的代码吧。

TensorFlow 版本:2.4.1

首先,主要是构建两个自定义层来实现池化索引:MaxPoolingWithArgmax2D(带索引的最大池化)和 MaxUnpooling2D(按索引反池化)。经过反复尝试,我懵懵懂懂地解决了直接搬运其他代码时出现的各种报错。

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


class MaxPoolingWithArgmax2D(Layer):
    """2-D max pooling that also returns the argmax index of every pooled window.

    Wraps ``tf.nn.max_pool_with_argmax`` on a channels-last (NHWC) input. The
    returned indices are intended for ``MaxUnpooling2D``, which places decoder
    values back at the positions where the encoder found its maxima
    (SegNet-style "pooling indices").

    Args:
        pool_size: (height, width) of the pooling window.
        strides: (height, width) stride of the pooling window.
        padding: ``'same'`` or ``'valid'`` (case-insensitive).
    """

    def __init__(self, pool_size=(2, 2), strides=(2, 2), padding='same', **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Return ``[pooled, argmax]`` for a 4-D NHWC ``inputs`` tensor.

        Raises:
            NotImplementedError: if the Keras backend is not TensorFlow.
        """
        if K.backend() != 'tensorflow':
            errmsg = '{} backend is not supported for layer {}'.format(K.backend(), type(self).__name__)
            raise NotImplementedError(errmsg)
        # tf.nn.max_pool_with_argmax expects 4-element NHWC ksize/strides.
        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs, ksize=ksize, strides=strides, padding=self.padding.upper())
        # The int64 indices are emitted as K.floatx(); MaxUnpooling2D casts them
        # back to int32.
        # NOTE: float32 represents integers exactly only up to 2**24, so
        # per-image indices (height * width * channels) beyond that lose precision.
        argmax = K.cast(argmax, K.floatx())
        return [output, argmax]

    def compute_output_shape(self, input_shape):
        """Return ``[pooled_shape, pooled_shape]`` for a 4-D NHWC ``input_shape``.

        Uses the same size rule as TensorFlow pooling:
        ``'same'`` -> ceil(dim / stride); ``'valid'`` -> (dim - window) // stride + 1.
        Unknown (``None``) dimensions stay ``None``.
        """
        same_padding = self.padding.lower() == 'same'

        def pooled(dim, window, stride):
            if dim is None:
                return None
            if same_padding:
                return -(-dim // stride)  # ceil(dim / stride) in integer arithmetic
            return (dim - window) // stride + 1

        batch, height, width, channels = tuple(input_shape)
        output_shape = (
            batch,
            pooled(height, self.pool_size[0], self.strides[0]),
            pooled(width, self.pool_size[1], self.strides[1]),
            channels,
        )
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        # Neither output (values, indices) carries a mask.
        return 2 * [None]

    def get_config(self):
        """Serialize constructor arguments so the layer can be reloaded."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update({
            "pool_size": self.pool_size,
            "strides": self.strides,
            "padding": self.padding,
        })
        return config


class MaxUnpooling2D(Layer):
    """Max unpooling driven by the indices from ``MaxPoolingWithArgmax2D``.

    Each value of ``updates`` is scattered to the position recorded in ``mask``.
    All other positions of the upsampled output are zero.

    Args:
        size: (height, width) upsampling factor. It should match the strides of
            the corresponding pooling layer.

    Inputs:
        ``[updates, mask]`` — two tensors of the same shape. ``mask`` holds the
        argmax indices (possibly as floats; they are cast to int32 here).
    """

    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        """Scatter ``updates`` into a zero tensor of ``output_shape``.

        Args:
            inputs: ``[updates, mask]``.
            output_shape: optional 4-D (batch, height, width, channels) shape of
                the result. It must equal the shape of the tensor that was
                originally pooled, because the indices are flattened against it.
                Defaults to the input height/width multiplied by ``self.size``.
                Pass it explicitly when the pooled tensor had an odd height or
                width.

        Returns:
            The unpooled tensor.
        """
        updates, mask = inputs[0], inputs[1]
        mask = K.cast(mask, 'int32')
        input_shape = tf.shape(updates, out_type='int32')
        explicit_shape = output_shape
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * self.size[0],
                            input_shape[2] * self.size[1],
                            input_shape[3])
            # Kept for backward compatibility with any code that reads it.
            self.output_shape1 = output_shape

        # Build (batch, row, column, channel) indices for every element of `updates`.
        one_like_mask = K.ones_like(mask, dtype='int32')
        batch_shape = K.concatenate([[input_shape[0]], [1], [1], [1]], axis=0)
        batch_range = K.reshape(tf.range(output_shape[0], dtype='int32'), shape=batch_shape)
        b = one_like_mask * batch_range
        # tf.nn.max_pool_with_argmax (include_batch_in_index=False) flattens a
        # position as (y * width + x) * channels + c, so invert that here.
        y = mask // (output_shape[2] * output_shape[3])
        x = (mask // output_shape[3]) % output_shape[2]
        # Pooling is per channel, so the channel of each index equals its own channel.
        feature_range = tf.range(output_shape[3], dtype='int32')
        f = one_like_mask * feature_range

        # Flatten into an (N, 4) index matrix and an (N,) value vector.
        updates_size = tf.size(updates)
        indices = K.transpose(K.reshape(K.stack([b, y, x, f]), [4, updates_size]))
        values = K.reshape(updates, [updates_size])
        ret = tf.scatter_nd(indices, values, output_shape)

        # scatter_nd with a computed shape loses static shape information.
        # Restore whatever is statically known so downstream layers (e.g. Conv2D,
        # which needs the channel count) can build. Unknown dims stay None
        # instead of crashing on `None * size`.
        if explicit_shape is None:
            static_in = updates.shape
            height = None if static_in[1] is None else static_in[1] * self.size[0]
            width = None if static_in[2] is None else static_in[2] * self.size[1]
            static_out = [None, height, width, static_in[3]]
        else:
            static_out = [None] + [d if isinstance(d, int) else None
                                   for d in tuple(explicit_shape)[1:]]
        ret.set_shape(static_out)
        return ret

    def compute_output_shape(self, input_shape):
        """Return the unpooled 4-D shape for ``[updates_shape, mask_shape]``."""
        mask_shape = input_shape[1]
        height = None if mask_shape[1] is None else mask_shape[1] * self.size[0]
        width = None if mask_shape[2] is None else mask_shape[2] * self.size[1]
        return mask_shape[0], height, width, mask_shape[3]

    def get_config(self):
        """Serialize constructor arguments so the layer can be reloaded."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({
            "size": self.size,
        })
        return config
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93

接下来是 SegNet 网络的主体部分。注意池化和反池化时 filters 数量要对得上:每个反池化层输入特征图的通道数,必须与它所使用的池化索引(来自对应的编码器池化层)的通道数一致。

def SegNet(fNum, dates, lossweights, filters=64):
    """Build a SegNet-style encoder/decoder ending in a 1-channel sigmoid map.

    Args:
        fNum: multiplied by ``dates`` to form the first dimension of the input
            shape ``(fNum * dates, img_h, img_w)``.
        dates: see ``fNum``.
        lossweights: accepted for interface compatibility; not used while
            building the model.
        filters: filter count of the first encoder stage. Encoder stages use
            filters * (1, 2, 4, 8); decoder stages use filters * (4, 2, 1) and
            then 16.

    Returns:
        An uncompiled ``keras.models.Model``.

    Relies on ``keras``, ``img_h``, ``img_w`` and ``reshapes2`` defined
    elsewhere in the module.
    # NOTE(review): assumes reshapes2 produces a channels-last (NHWC) tensor,
    # which tf.nn.max_pool_with_argmax requires — TODO confirm.
    """

    def double_conv(tensor, width):
        # Two (3x3 Conv2D + ReLU -> BatchNormalization) units, created in order.
        for _ in range(2):
            tensor = keras.layers.Conv2D(width, (3, 3), activation='relu', padding='same')(tensor)
            tensor = keras.layers.BatchNormalization()(tensor)
        return tensor

    model_input = keras.layers.Input((fNum * dates, img_h, img_w))
    net = keras.layers.Lambda(reshapes2)(model_input)  # data-specific reshape

    # Encoder: four stages, each ending in an index-recording max pool.
    saved_indices = []
    for multiplier in (1, 2, 4, 8):
        net = double_conv(net, filters * multiplier)
        net, pool_idx = MaxPoolingWithArgmax2D(pool_size=(2, 2))(net)
        saved_indices.append(pool_idx)

    # Decoder: unpool with the indices of the mirrored encoder stage (deepest
    # first). The tensor entering each unpooling has the same channel count
    # as those indices.
    decoder_widths = (filters * 4, filters * 2, filters, 16)
    for pool_idx, width in zip(reversed(saved_indices), decoder_widths):
        net = MaxUnpooling2D(size=(2, 2))([net, pool_idx])
        net = double_conv(net, width)

    prediction = keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(net)
    return keras.models.Model(inputs=model_input, outputs=prediction)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

预测部分注意事项:
直接加载模型进行预测会报错:ValueError: Unknown layer: MaxPoolingWithArgmax2D
需要在 load_model 时通过 custom_objects 参数声明自定义层(以及自定义的损失函数、评价指标等)。

# Reload the trained model for prediction. Custom layers, losses and metrics
# used by the saved model must be passed via custom_objects; otherwise Keras
# fails with e.g. "ValueError: Unknown layer: MaxPoolingWithArgmax2D".
# NOTE(review): the plain names (tf, BS, img_h, img_w, n_label) are presumably
# needed by Lambda layers such as reshapes2 when they are deserialized — confirm.
# NOTE(review): 'lr' maps to get_lr_metric itself rather than to a metric it
# returns — confirm that is what the saved model expects.
model = load_model(
    inmodel,
    custom_objects=dict(
        weighted_cross_entropy=weighted_cross_entropy(lossweights),
        loss=weighted_cross_entropy(lossweights),
        recall=recall,
        precision=precision,
        kappa_metrics=kappa_metrics,
        fmeasure=fmeasure,
        lr=get_lr_metric,
        OA=OA,
        tf=tf,
        BS=BS,
        img_h=128,
        img_w=128,
        n_label=1,
        MaxPoolingWithArgmax2D=MaxPoolingWithArgmax2D,
        MaxUnpooling2D=MaxUnpooling2D,
    ),
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/641461
推荐阅读
相关标签
  

闽ICP备14008679号