赞
踩
我真服了原来我之前用tf复现SegNet给复现错了
在网上试了多个版本代码,折腾了好久,现在终于复现对了,代码也跑通了
SegNet的架构比较老了,这几年都没人更新代码了,我这里算是提供一个最近能跑通的版本的代码吧
所用TensorFlow版本:2.4.1
首先需要构建两个自定义层(MaxPoolingWithArgmax2D 和 MaxUnpooling2D)来实现带索引的池化与反池化。网上的其它代码直接搬运过来会出现各种报错,下面是我经过反复尝试后能够跑通的版本:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


class MaxPoolingWithArgmax2D(Layer):
    """Max pooling layer that also returns the flat index of every maximum.

    The layer returns ``[pooled, argmax]``. ``argmax`` is the index tensor
    produced by ``tf.nn.max_pool_with_argmax`` cast to ``K.floatx()``; it is
    meant to be fed to ``MaxUnpooling2D`` as the second input.

    Args:
        pool_size: ``(height, width)`` of the pooling window.
        strides: ``(height, width)`` step of the pooling window.
        padding: ``'same'`` or ``'valid'`` (case-insensitive).
    """

    def __init__(self, pool_size=(2, 2), strides=(2, 2), padding='same', **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Pool ``inputs`` and return ``[output, argmax]``.

        Raises:
            NotImplementedError: if the Keras backend is not TensorFlow.
        """
        if K.backend() != 'tensorflow':
            errmsg = '{} backend is not supported for layer {}'.format(K.backend(), type(self).__name__)
            raise NotImplementedError(errmsg)

        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs, ksize=ksize, strides=strides, padding=self.padding.upper())
        # The indices are returned as floats so both outputs share a dtype.
        # NOTE(review): if K.floatx() is float32, integers above 2**24 are not
        # exactly representable, so very large feature maps could get
        # corrupted indices -- confirm against the sizes actually used.
        argmax = K.cast(argmax, K.floatx())
        return [output, argmax]

    def _pooled_dim(self, dim, pool, stride):
        """Return the pooled length of one spatial axis (``None`` stays ``None``)."""
        if dim is None:
            return None
        if self.padding.upper() == 'SAME':
            return -(-dim // stride)  # ceil(dim / stride)
        return max((dim - pool) // stride + 1, 0)

    def compute_output_shape(self, input_shape):
        """Return the (identical) static shapes of the two outputs.

        The spatial sizes follow the configured ``pool_size``, ``strides`` and
        ``padding`` instead of assuming a fixed factor of two.
        """
        dims = list(input_shape)
        output_shape = (
            dims[0],
            self._pooled_dim(dims[1], self.pool_size[0], self.strides[0]),
            self._pooled_dim(dims[2], self.pool_size[1], self.strides[1]),
            dims[3],
        )
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        """Neither output carries a Keras mask."""
        return 2 * [None]

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update({
            "pool_size": self.pool_size,
            "strides": self.strides,
            "padding": self.padding,
        })
        return config


class MaxUnpooling2D(Layer):
    """Inverse of ``MaxPoolingWithArgmax2D``: scatter values back to their argmax.

    The layer is called on ``[updates, mask]`` where ``mask`` holds the flat
    indices produced by the pooling layer. Every position of the output that
    is not addressed by an index is zero.

    Args:
        size: ``(height, width)`` up-sampling factor used when no explicit
            ``output_shape`` is passed to ``call``.
    """

    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    @staticmethod
    def _scaled(dim, factor):
        """Multiply a static dimension by ``factor``; ``None`` stays ``None``."""
        return None if dim is None else dim * factor

    @staticmethod
    def _static_dim(dim):
        """Return ``dim`` as a Python int if it is known statically, else ``None``."""
        value = tf.get_static_value(dim) if tf.is_tensor(dim) else dim
        return None if value is None else int(value)

    def call(self, inputs, output_shape=None):
        """Scatter ``inputs[0]`` into a zero tensor at the indices ``inputs[1]``.

        Args:
            inputs: ``[updates, mask]``; both must have the same shape.
            output_shape: optional 4-element ``(batch, height, width, channels)``
                shape of the result. Defaults to the input shape with height
                and width multiplied by ``size``.

        Returns:
            The unpooled 4-D tensor.
        """
        updates, mask = inputs[0], inputs[1]
        explicit_shape = output_shape is not None
        with tf.compat.v1.variable_scope(self.name):
            mask = K.cast(mask, 'int32')
            input_shape = tf.shape(updates, out_type='int32')
            # Work out the shape of the unpooled tensor.
            if output_shape is None:
                output_shape = (input_shape[0],
                                input_shape[1] * self.size[0],
                                input_shape[2] * self.size[1],
                                input_shape[3])
            # Kept for backward compatibility; nothing in this class reads it.
            self.output_shape1 = output_shape

            # Build the batch / row / column / channel index of every value.
            # With the default include_batch_in_index=False the pooling op
            # flattens an index as (y * width + x) * channels + c, so the
            # decoding below is only valid when output_shape[2] and
            # output_shape[3] equal the width and channel count of the tensor
            # that was pooled.
            one_like_mask = K.ones_like(mask, dtype='int32')
            batch_shape = K.concatenate([[input_shape[0]], [1], [1], [1]], axis=0)
            batch_range = K.reshape(tf.range(output_shape[0], dtype='int32'), shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # Transpose the indices and flatten the values to one dimension.
            updates_size = tf.size(updates)
            indices = K.transpose(K.reshape(K.stack([b, y, x, f]), [4, updates_size]))
            values = K.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)

            # Recover as much static shape information as possible so that
            # following layers (e.g. Conv2D) can be built.
            static_in = updates.shape
            if explicit_shape:
                # Honour the caller's shape instead of forcing input * size.
                spatial = [self._static_dim(output_shape[1]),
                           self._static_dim(output_shape[2])]
                channels = self._static_dim(output_shape[3])
                if channels is None:
                    channels = static_in[3]
            else:
                spatial = [self._scaled(static_in[1], self.size[0]),
                           self._scaled(static_in[2], self.size[1])]
                channels = static_in[3]
            target = spatial + [channels]

        if None not in target:
            return K.reshape(ret, [-1] + target)
        # Some dimensions are only known at run time: annotate what is known
        # rather than reshaping with several unknown sizes.
        ret.set_shape([None] + target)
        return ret

    def compute_output_shape(self, input_shape):
        """Return the mask shape with height and width multiplied by ``size``."""
        mask_shape = input_shape[1]
        return (mask_shape[0],
                self._scaled(mask_shape[1], self.size[0]),
                self._scaled(mask_shape[2], self.size[1]),
                mask_shape[3])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({
            "size": self.size,
        })
        return config
接下来是SegNet网络主体部分。注意:每个反池化层输入特征图的通道数(filters)必须与对应池化层输出的索引的通道数一致,否则形状对不上:
def SegNet(fNum, dates, lossweights, filters=64):
    """Build the SegNet-style encoder/decoder Keras model.

    The encoder has four stages of two Conv-BatchNorm pairs followed by an
    index-preserving max pooling; the decoder mirrors it with four
    index-driven unpooling steps, each followed by two Conv-BatchNorm pairs,
    and ends in a 1x1 sigmoid convolution with a single output channel.

    Args:
        fNum: multiplied with ``dates`` to give the first input dimension.
        dates: see ``fNum``.
        lossweights: not used inside this function.
        filters: number of filters of the first encoder stage; later stages
            use 2x, 4x and 8x this value.

    Returns:
        The un-compiled ``keras.models.Model``.

    ``keras``, ``img_h``, ``img_w`` and ``reshapes2`` are taken from the
    enclosing module scope.
    """
    layers = keras.layers

    def conv_bn_pair(tensor, n_filters):
        # Two consecutive 3x3 ReLU convolutions, each followed by batch norm.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = layers.BatchNormalization()(tensor)
        return tensor

    inputs = layers.Input((fNum * dates, img_h, img_w))
    # Reshape specific to the author's data layout.
    features = layers.Lambda(reshapes2)(inputs)

    # Encoder: remember the pooling indices of every stage.
    pooling_indices = []
    for stage_filters in (filters, filters * 2, filters * 4, filters * 8):
        features = conv_bn_pair(features, stage_filters)
        features, indices = MaxPoolingWithArgmax2D(pool_size=(2, 2))(features)
        pooling_indices.append(indices)

    # Decoder: consume the indices in reverse order (deepest stage first).
    for stage_filters in (filters * 4, filters * 2, filters, 16):
        features = MaxUnpooling2D(size=(2, 2))([features, pooling_indices.pop()])
        features = conv_bn_pair(features, stage_filters)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(features)
    return keras.models.Model(inputs=inputs, outputs=outputs)
预测部分注意事项
直接预测会出现报错:ValueError: Unknown layer: MaxPoolingWithArgmax2D
需要在load_model时通过custom_objects参数声明这两个自定义层(以及自定义的损失函数和评价指标):
# Objects handed to load_model through custom_objects; the last two entries
# register the custom pooling / unpooling layers.
custom_objects = {
    "weighted_cross_entropy": weighted_cross_entropy(lossweights),
    "loss": weighted_cross_entropy(lossweights),
    "recall": recall,
    "precision": precision,
    "kappa_metrics": kappa_metrics,
    "fmeasure": fmeasure,
    "lr": get_lr_metric,
    "OA": OA,
    "tf": tf,
    "BS": BS,
    "img_h": 128,
    "img_w": 128,
    "n_label": 1,
    "MaxPoolingWithArgmax2D": MaxPoolingWithArgmax2D,
    "MaxUnpooling2D": MaxUnpooling2D,
}
model = load_model(inmodel, custom_objects=custom_objects)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。