Kaiming He 等人提出了残差网络(ResNet)来解决上文所述的退化问题,其基本思想如图6所示。
图6 残差块设计思想
图6(b)的结构是残差网络的基础,这种结构也叫做残差块(residual block)。输入x通过跨层连接(shortcut connection)直接传到后面的层,使数据能更快地向前传播,梯度也能更快地向后传播。残差块的具体设计方案如图7所示,这种设计方案也称作瓶颈结构(BottleNeck)。
图7 残差块结构示意图
在ResNet-34的结构图中,用“实线”表示的跳跃连接意味着identity mapping与residual mapping的通道数(及特征图尺寸)相同,可以直接相加;用“虚线”表示的跳跃连接意味着两者的维度不同,需要先用1x1卷积(下采样时步长为2)调整identity分支的通道数和尺寸,才能使两者相加。
基本设置沿用了以往的经典网络,具体可参阅原文的参考文献。在每次卷积之后、激活之前,我们都使用批量归一化(BN)。我们按原文参考文献中的方法初始化权重,并从头开始训练所有普通网络/残差网络。我们使用小批量(mini-batch)大小为256的SGD。学习率从0.1开始,每当误差趋于平稳时除以10,模型最多训练60万(600000)次迭代。我们使用0.0001的权重衰减(weight decay)和0.9的动量。
from keras.layers import Input
from keras.layers import Conv2D, MaxPool2D, Dense, BatchNormalization, Activation, add, GlobalAvgPool2D
from keras.models import Model
from keras import regularizers
from keras.utils import plot_model
from keras import backend as K


def conv2d_bn(x, nb_filter, kernel_size, strides=(1, 1), padding='same', activation='relu'):
    """Apply Conv2D -> BatchNormalization -> (optional) activation.

    :param x: input tensor.
    :param nb_filter: number of convolution filters.
    :param kernel_size: convolution kernel size.
    :param strides: convolution strides as (rows, cols).
    :param padding: Conv2D padding mode.
    :param activation: activation applied after BN. ``None`` skips it; the
        default 'relu' keeps the original conv -> BN -> ReLU behaviour.
    :return: output tensor.
    """
    x = Conv2D(nb_filter, kernel_size=kernel_size, strides=strides, padding=padding,
               kernel_regularizer=regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    if activation is not None:
        x = Activation(activation)(x)
    return x


def shortcut(input, residual):
    """Add the shortcut (identity mapping) branch to the residual branch.

    If the residual branch changed the spatial size or the channel count, the
    shortcut is projected with a strided 1x1 convolution first so that the two
    tensors can be added element-wise.

    :param input: block input tensor (the shortcut branch).
    :param residual: output tensor of the residual branch.
    :return: tensor ``identity + residual``.
    """
    input_shape = K.int_shape(input)
    residual_shape = K.int_shape(residual)
    # The ratio of spatial sizes is the stride used inside the residual branch.
    stride_height = int(round(input_shape[1] / residual_shape[1]))
    stride_width = int(round(input_shape[2] / residual_shape[2]))
    equal_channels = input_shape[3] == residual_shape[3]

    identity = input
    # Dimensions differ: project the shortcut with a 1x1 convolution.
    if stride_width > 1 or stride_height > 1 or not equal_channels:
        # Keras strides are ordered (rows, cols) == (height, width).
        identity = Conv2D(filters=residual_shape[3], kernel_size=(1, 1),
                          strides=(stride_height, stride_width), padding="valid",
                          kernel_regularizer=regularizers.l2(0.0001))(input)
    return add([identity, residual])


def basic_block(nb_filter, strides=(1, 1)):
    """Basic ResNet building block (two 3x3 convs), used by ResNet-18/34.

    :param nb_filter: number of filters of both 3x3 convolutions.
    :param strides: strides of the first convolution; (2, 2) downsamples.
    :return: a function mapping an input tensor to the block output.
    """
    def f(x):
        conv1 = conv2d_bn(x, nb_filter, kernel_size=(3, 3), strides=strides)
        # The second conv is not activated: as in the paper, the ReLU is
        # applied after the addition, i.e. y = relu(F(x) + shortcut(x)).
        residual = conv2d_bn(conv1, nb_filter, kernel_size=(3, 3), activation=None)
        return Activation('relu')(shortcut(x, residual))
    return f


def residual_block(nb_filter, repetitions, is_first_layer=False):
    """Build one stage of stacked basic blocks (conv2_x .. conv5_x in the paper's table).

    :param nb_filter: number of filters of every block in the stage.
    :param repetitions: number of basic blocks in the stage.
    :param is_first_layer: if True, the first block keeps the resolution
        (stride 1); otherwise it downsamples with stride 2.
    :return: a function mapping an input tensor to the stage output.
    """
    def f(x):
        for i in range(repetitions):
            strides = (1, 1)
            if i == 0 and not is_first_layer:
                strides = (2, 2)
            x = basic_block(nb_filter, strides)(x)
        return x
    return f


def resnet_18(input_shape=(224, 224, 3), nclass=1000):
    """
    build resnet-18 model using keras with TensorFlow backend.
    :param input_shape: input shape of network, default as (224,224,3)
    :param nclass: numbers of class(output shape of network), default as 1000
    :return: resnet-18 model
    """
    input_ = Input(shape=input_shape)

    conv1 = conv2d_bn(input_, 64, kernel_size=(7, 7), strides=(2, 2))
    pool1 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1)

    # conv2_x follows the stride-2 max-pool and keeps its resolution; each of
    # conv3_x .. conv5_x halves it in its first block (for a 224x224 input:
    # 56 -> 28 -> 14 -> 7), as in the paper's architecture table.
    conv2 = residual_block(64, 2, is_first_layer=True)(pool1)
    conv3 = residual_block(128, 2)(conv2)
    conv4 = residual_block(256, 2)(conv3)
    conv5 = residual_block(512, 2)(conv4)

    pool2 = GlobalAvgPool2D()(conv5)
    output_ = Dense(nclass, activation='softmax')(pool2)

    model = Model(inputs=input_, outputs=output_)
    model.summary()

    return model


if __name__ == '__main__':
    model = resnet_18()
    plot_model(model, 'ResNet-18.png')  # save a diagram of the model
import keras
import argparse
import numpy as np
from keras.datasets import cifar10, cifar100
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.layers import Conv2D, Dense, Input, add, Activation, GlobalAveragePooling2D
from keras.callbacks import LearningRateScheduler, TensorBoard, ModelCheckpoint
from keras.models import Model
from keras import optimizers, regularizers
from keras import backend as K

# set GPU memory
if 'tensorflow' == K.backend():
    import tensorflow as tf
    from keras.backend.tensorflow_backend import set_session
    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # The session must be registered with Keras; otherwise Keras creates its
    # own default session and the allow_growth option above is ignored.
    set_session(sess)

# set parameters via parser
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--batch_size', type=int, default=128, metavar='NUMBER',
                    help='batch size(default: 128)')
parser.add_argument('-e', '--epochs', type=int, default=200, metavar='NUMBER',
                    help='epochs(default: 200)')
parser.add_argument('-n', '--stack_n', type=int, default=5, metavar='NUMBER',
                    help='stack number n, total layers = 6 * n + 2 (default: 5)')
# Only these two datasets are handled below; reject anything else instead of
# silently training on CIFAR-10 under a misleading log/checkpoint name.
parser.add_argument('-d', '--dataset', type=str, default="cifar10", metavar='STRING',
                    choices=['cifar10', 'cifar100'],
                    help='dataset. (default: cifar10)')

args = parser.parse_args()

stack_n = args.stack_n
layers = 6 * stack_n + 2
num_classes = 10
img_rows, img_cols = 32, 32
img_channels = 3
batch_size = args.batch_size
epochs = args.epochs
# Steps per epoch over the 50,000 CIFAR training images (ceiling division,
# so no extra step is added when batch_size divides 50,000 exactly).
iterations = (50000 + batch_size - 1) // batch_size
weight_decay = 1e-4


def color_preprocessing(x_train, x_test):
    """Standardize each RGB channel with fixed per-channel mean/std.

    Inputs are cast to float32 and normalized as (x - mean[c]) / std[c].
    NOTE(review): the constants look like CIFAR-10 training-set statistics
    (0-255 range); they are applied to CIFAR-100 as well.
    """
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    mean = [125.307, 122.95, 113.865]
    std = [62.9932, 62.0887, 66.7048]
    for i in range(3):
        x_train[:, :, :, i] = (x_train[:, :, :, i] - mean[i]) / std[i]
        x_test[:, :, :, i] = (x_test[:, :, :, i] - mean[i]) / std[i]
    return x_train, x_test


def scheduler(epoch):
    """Step learning-rate schedule: 0.1 before epoch 81, 0.01 before 122, then 0.001."""
    if epoch < 81:
        return 0.1
    if epoch < 122:
        return 0.01
    return 0.001


def residual_network(img_input, classes_num=10, stack_n=5):
    """Build a CIFAR ResNet with 6 * stack_n + 2 layers and return its output tensor.

    Blocks use the BN -> ReLU -> conv ordering before each convolution.
    """
    def residual_block(x, o_filters, increase=False):
        # `increase` halves the spatial size (stride 2) and switches the
        # shortcut to a 1x1 projection so the shapes match for the addition.
        stride = (1, 1)
        if increase:
            stride = (2, 2)

        o1 = Activation('relu')(BatchNormalization(momentum=0.9, epsilon=1e-5)(x))
        conv_1 = Conv2D(o_filters, kernel_size=(3, 3), strides=stride, padding='same',
                        kernel_initializer="he_normal",
                        kernel_regularizer=regularizers.l2(weight_decay))(o1)
        o2 = Activation('relu')(BatchNormalization(momentum=0.9, epsilon=1e-5)(conv_1))
        conv_2 = Conv2D(o_filters, kernel_size=(3, 3), strides=(1, 1), padding='same',
                        kernel_initializer="he_normal",
                        kernel_regularizer=regularizers.l2(weight_decay))(o2)
        if increase:
            projection = Conv2D(o_filters, kernel_size=(1, 1), strides=(2, 2), padding='same',
                                kernel_initializer="he_normal",
                                kernel_regularizer=regularizers.l2(weight_decay))(o1)
            block = add([conv_2, projection])
        else:
            block = add([conv_2, x])
        return block

    # build model ( total layers = stack_n * 3 * 2 + 2 )
    # stack_n = 5 by default, total layers = 32
    # input: 32x32x3 output: 32x32x16
    x = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same',
               kernel_initializer="he_normal",
               kernel_regularizer=regularizers.l2(weight_decay))(img_input)

    # input: 32x32x16 output: 32x32x16
    for _ in range(stack_n):
        x = residual_block(x, 16, False)

    # input: 32x32x16 output: 16x16x32
    x = residual_block(x, 32, True)
    for _ in range(1, stack_n):
        x = residual_block(x, 32, False)

    # input: 16x16x32 output: 8x8x64
    x = residual_block(x, 64, True)
    for _ in range(1, stack_n):
        x = residual_block(x, 64, False)

    x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    # input: 64 output: 10
    x = Dense(classes_num, activation='softmax', kernel_initializer="he_normal",
              kernel_regularizer=regularizers.l2(weight_decay))(x)
    return x


if __name__ == '__main__':

    print("========================================")
    print("MODEL: Residual Network ({:2d} layers)".format(6*stack_n+2))
    print("BATCH SIZE: {:3d}".format(batch_size))
    print("WEIGHT DECAY: {:.4f}".format(weight_decay))
    print("EPOCHS: {:3d}".format(epochs))
    print("DATASET: {:}".format(args.dataset))

    print("== LOADING DATA... ==")
    # load data (module scope: this rebinds the module-level num_classes)
    if args.dataset == "cifar100":
        num_classes = 100
        (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    else:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    print("== DONE! ==\n== COLOR PREPROCESSING... ==")
    # color preprocessing
    x_train, x_test = color_preprocessing(x_train, x_test)

    print("== DONE! ==\n== BUILD MODEL... ==")
    # build network
    img_input = Input(shape=(img_rows, img_cols, img_channels))
    output = residual_network(img_input, num_classes, stack_n)
    resnet = Model(img_input, output)

    # print model architecture if you need.
    # print(resnet.summary())

    # set optimizer
    sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    resnet.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    # set callback
    cbks = [TensorBoard(log_dir='./resnet_{:d}_{}/'.format(layers, args.dataset), histogram_freq=0),
            LearningRateScheduler(scheduler)]

    # dump checkpoint if you need.(add it to cbks)
    # ModelCheckpoint('./checkpoint-{epoch}.h5', save_best_only=False, mode='auto', period=10)

    # set data augmentation
    print("== USING REAL-TIME DATA AUGMENTATION, START TRAIN... ==")
    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant', cval=0.)

    datagen.fit(x_train)

    # start training
    resnet.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                         steps_per_epoch=iterations,
                         epochs=epochs,
                         callbacks=cbks,
                         validation_data=(x_test, y_test))
    resnet.save('resnet_{:d}_{}.h5'.format(layers, args.dataset))
