There are already plenty of online walkthroughs of the Residual paper by Kaiming He and his colleagues. After studying it for a while, I wanted to work out a set of generic functions that fit every ResNet variant, to make experimenting easier, so in this post I focus on how to implement those generic functions.
Key point 1:
In the figure above, note that F(x) + x is computed before the ReLU activation. Knowing this tells us that the convolution helper conv2D_BN must not apply an activation itself.
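To make the ordering concrete, here is a minimal sketch (the function name residual_add_sketch is illustrative, not from the paper):

from tensorflow import keras

def residual_add_sketch(x, f_x):
    # f_x is the output of the block's last Conv2D + BN, with no activation yet
    y = keras.layers.add([x, f_x])             # F(x) + x comes first
    return keras.layers.Activation('relu')(y)  # ReLU is applied only after the add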
Key point 2:
By inspecting the figure we can see that only the first block of each stage of residual blocks needs a 1x1 convolution on its shortcut connection; those shortcuts are drawn as dashed lines, for example at position 3. There is one exception to watch out for, at position 1: it is also the first block of a stage, yet it needs no 1x1 convolution, and its shortcut is drawn solid. Likewise, shortcuts inside a stage, such as the one at position 2, need no 1x1 convolution and are also drawn solid.
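As a rough illustration of that rule (the helper name shortcut_sketch and the needs_projection flag are my own, not from the paper):

from tensorflow import keras

def shortcut_sketch(x, out_filters, needs_projection):
    # Dashed shortcut: a 1x1 conv with stride 2 matches both the halved
    # feature-map size and the new number of filters of the main branch.
    if needs_projection:
        return keras.layers.Conv2D(out_filters, (1, 1), strides=(2, 2))(x)
    # Solid shortcut: plain identity, no extra parameters needed.
    return x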
Key point 3:
To build generic functions that fit every residual type, we need a table that encodes the information above, as shown below:
# 50-layer
# e.g. [3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]]]: the first element, 3, is the
# number of residual blocks in this stage; the second element lists the
# parameters of each conv layer in the block, i.e. filters and kernel_size
filter_list_resnet50 = [ [3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]] ],
                         [4, [[128,(1,1)],[128,(3,3)],[512,(1,1)]] ],
                         [6, [[256,(1,1)],[256,(3,3)],[1024,(1,1)]] ],
                         [3, [[512,(1,1)],[512,(3,3)],[2048,(1,1)]] ]]
# 18-layer
filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                         [2, [[128,(3,3)],[128,(3,3)]] ],
                         [2, [[256,(3,3)],[256,(3,3)]] ],
                         [2, [[512,(3,3)],[512,(3,3)]] ]]
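For instance, the first entry of the ResNet-18 table unpacks like this (a quick check, not part of the network itself):

num_blocks, layer_specs = filter_list_resnet18[0]
print(num_blocks)      # 2 -> this stage repeats its block twice
print(layer_specs[0])  # [64, (3, 3)] -> filters and kernel_size of the first conv layer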
Key point 4:
Sometimes the feature-map size stays the same and sometimes it is halved, so the corresponding padding may differ.
Except for the 1x1 convolutions, which use padding='valid', everything uses padding='same'.
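A quick worked check of why the 1x1 case is safe: under Keras rules, padding='same' gives an output size of ceil(n/stride), and for a 1x1 kernel padding='valid' gives exactly the same size, because a 1x1 window never runs past the edge. The helper below just evaluates those two formulas:

import math

def out_size(n, k, s, padding):
    # Output length of one spatial dimension under Keras padding rules
    if padding == 'same':
        return math.ceil(n / s)
    return (n - k) // s + 1  # 'valid'

print(out_size(8, 1, 2, 'same'))   # 4
print(out_size(8, 1, 2, 'valid'))  # 4 -> identical for a 1x1 kernel
print(out_size(8, 3, 2, 'same'))   # 4 -> 'same' keeps the halving exact
print(out_size(8, 3, 2, 'valid'))  # 3 -> 'valid' would shrink the map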
Step 1: import the libraries
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
Step 2: implement convolution plus batch normalization
def conv2D_BN(x, num_filter, kernel_size, strides=(1,1), padding='same'):
    '''
    To allow F(x) + x to be computed first and ReLU applied afterwards,
    this convolution helper applies no activation function.
    '''
    conv = keras.layers.Conv2D(filters=num_filter, kernel_size=kernel_size,
                               strides=strides, padding=padding,
                               kernel_regularizer=keras.regularizers.l2(0.0001))(x)
    bn = keras.layers.BatchNormalization()(conv)
    return bn
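A quick shape check of the helper on a dummy input (the sizes here are arbitrary):

dummy = keras.layers.Input(shape=(32, 32, 3))
out = conv2D_BN(dummy, num_filter=64, kernel_size=(3, 3), strides=(2, 2))
print(out.shape)  # (None, 16, 16, 64): stride 2 halves the feature map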
Step 3: implement the basic residual block
From the paper we know that if the output map size stays the same, the number of filters stays the same; if the size is halved, the number of filters is doubled. The former happens inside a building_block, the latter between building_blocks.
From this we can settle on padding='same': with strides=1 the feature-map size is unchanged, and with strides=2 it is halved.
def building_block(x, filters_list, is_first_layers=False):
    '''
    A basic residual block that works for every residual block type.
    is_first_layers=True means strides=2, so the feature-map size must be
    halved, and the shortcut needs a 1x1 convolution (the dashed shortcut);
    otherwise strides=1, the feature-map size is unchanged, and the shortcut
    needs no change.
    filters_list: the conv-layer parameters of one block, i.e. a list of
    [number of filters, kernel size] entries, one per layer.
    '''
    y = x
    strides = (1, 1)
    for i in range(len(filters_list)):
        # Only the first conv of a downsampling block uses stride 2
        if is_first_layers and i == 0:
            strides = (2, 2)
        else:
            strides = (1, 1)
        y = conv2D_BN(y, filters_list[i][0], kernel_size=filters_list[i][1],
                      strides=strides)
        # shortcut:
        # If is_first_layers is True and this is the last layer of the block,
        # the shortcut needs a 1x1 convolution with stride 2 so that x, the
        # block input, is halved as well. Its number of filters must match
        # the current (last) layer, i.e. filters=filters_list[i][0].
        if is_first_layers and i == len(filters_list) - 1:
            x = conv2D_BN(x, filters_list[i][0], kernel_size=(1, 1),
                          strides=(2, 2), padding='valid')
            break
        # The last layer of the block is not activated yet: the block input
        # and output must be added first. Between the internal layers of the
        # block we can activate directly.
        if i == len(filters_list) - 1:
            break
        y = keras.layers.Activation('relu')(y)
    f = keras.layers.add([x, y])
    return keras.layers.Activation('relu')(f)
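A sanity check of one downsampling basic block on a dummy tensor (the shapes are arbitrary):

inp = keras.layers.Input(shape=(8, 8, 64))
out = building_block(inp, filters_list=[[128, (3, 3)], [128, (3, 3)]],
                     is_first_layers=True)
print(out.shape)  # (None, 4, 4, 128): size halved, filters doubled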
Step 4: implement the main body of the residual network, i.e. the part that differs between variants
def residual_main_network(x, filter_list_resnet):
    for i in range(len(filter_list_resnet)):
        for j in range(filter_list_resnet[i][0]):
            # The first block of a stage (j == 0) needs the dashed shortcut,
            # except in the very first stage (so i != 0), whose first block
            # needs no 1x1 shortcut convolution.
            if j == 0 and i != 0:
                is_first_layers = True
            else:
                is_first_layers = False
            x = building_block(x, filters_list=filter_list_resnet[i][1],
                               is_first_layers=is_first_layers)
    return x
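Checked against the ResNet-18 table defined above (the 8x8 input matches the feature map after the stem on a 32x32 image):

inp = keras.layers.Input(shape=(8, 8, 64))
out = residual_main_network(inp, filter_list_resnet18)
print(out.shape)  # (None, 1, 1, 512): three downsampling stages, 8 -> 4 -> 2 -> 1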
Step 5: implement the ResNet
def resnet(nclass, input_shape, filter_list_resnet):  # nclass is the number of output classes; input_shape is the input shape
    input_ = keras.layers.Input(shape=input_shape)
    conv1 = conv2D_BN(input_, 64, kernel_size=(7,7), strides=(2,2))
    conv1 = keras.layers.Activation('relu')(conv1)
    pool1 = keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1)

    conv2 = residual_main_network(pool1, filter_list_resnet)

    pool2 = keras.layers.GlobalAvgPool2D()(conv2)
    output_ = keras.layers.Dense(nclass, activation='softmax')(pool2)

    model = keras.Model(inputs=input_, outputs=output_)
    return model
filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                         [2, [[128,(3,3)],[128,(3,3)]] ],
                         [2, [[256,(3,3)],[256,(3,3)]] ],
                         [2, [[512,(3,3)],[512,(3,3)]] ]]
model = resnet(10, (32,32,3), filter_list_resnet18)
model.summary()
Output:
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0
conv2d_66 (Conv2D)           (None, 16, 16, 64)        9472
batch_normalization_66 (Batc (None, 16, 16, 64)        256
activation_55 (Activation)   (None, 16, 16, 64)        0
max_pooling2d_4 (MaxPooling2 (None, 8, 8, 64)          0
conv2d_67 (Conv2D)           (None, 8, 8, 64)          36928
batch_normalization_67 (Batc (None, 8, 8, 64)          256
activation_56 (Activation)   (None, 8, 8, 64)          0
conv2d_68 (Conv2D)           (None, 8, 8, 64)          36928
batch_normalization_68 (Batc (None, 8, 8, 64)          256
activation_57 (Activation)   (None, 8, 8, 64)          0
conv2d_69 (Conv2D)           (None, 8, 8, 64)          36928
batch_normalization_69 (Batc (None, 8, 8, 64)          256
activation_58 (Activation)   (None, 8, 8, 64)          0
conv2d_70 (Conv2D)           (None, 8, 8, 64)          36928
batch_normalization_70 (Batc (None, 8, 8, 64)          256
activation_59 (Activation)   (None, 8, 8, 64)          0
conv2d_71 (Conv2D)           (None, 4, 4, 128)         73856
batch_normalization_71 (Batc (None, 4, 4, 128)         512
activation_60 (Activation)   (None, 4, 4, 128)         0
conv2d_72 (Conv2D)           (None, 4, 4, 128)         147584
batch_normalization_72 (Batc (None, 4, 4, 128)         512
activation_61 (Activation)   (None, 4, 4, 128)         0
conv2d_74 (Conv2D)           (None, 4, 4, 128)         147584
batch_normalization_74 (Batc (None, 4, 4, 128)         512
activation_62 (Activation)   (None, 4, 4, 128)         0
conv2d_75 (Conv2D)           (None, 4, 4, 128)         147584
batch_normalization_75 (Batc (None, 4, 4, 128)         512
activation_63 (Activation)   (None, 4, 4, 128)         0
conv2d_76 (Conv2D)           (None, 2, 2, 256)         295168
batch_normalization_76 (Batc (None, 2, 2, 256)         1024
activation_64 (Activation)   (None, 2, 2, 256)         0
conv2d_77 (Conv2D)           (None, 2, 2, 256)         590080
batch_normalization_77 (Batc (None, 2, 2, 256)         1024
activation_65 (Activation)   (None, 2, 2, 256)         0
conv2d_79 (Conv2D)           (None, 2, 2, 256)         590080
batch_normalization_79 (Batc (None, 2, 2, 256)         1024
activation_66 (Activation)   (None, 2, 2, 256)         0
conv2d_80 (Conv2D)           (None, 2, 2, 256)         590080
batch_normalization_80 (Batc (None, 2, 2, 256)         1024
activation_67 (Activation)   (None, 2, 2, 256)         0
conv2d_81 (Conv2D)           (None, 1, 1, 512)         1180160
batch_normalization_81 (Batc (None, 1, 1, 512)         2048
activation_68 (Activation)   (None, 1, 1, 512)         0
conv2d_82 (Conv2D)           (None, 1, 1, 512)         2359808
batch_normalization_82 (Batc (None, 1, 1, 512)         2048
activation_69 (Activation)   (None, 1, 1, 512)         0
conv2d_84 (Conv2D)           (None, 1, 1, 512)         2359808
batch_normalization_84 (Batc (None, 1, 1, 512)         2048
activation_70 (Activation)   (None, 1, 1, 512)         0
conv2d_85 (Conv2D)           (None, 1, 1, 512)         2359808
batch_normalization_85 (Batc (None, 1, 1, 512)         2048
activation_71 (Activation)   (None, 1, 1, 512)         0
global_average_pooling2d_3 ( (None, 512)               0
dense_3 (Dense)              (None, 10)                5130
=================================================================
Total params: 11,019,530
Trainable params: 11,011,722
Non-trainable params: 7,808
_________________________________________________________________
Now let's train this model on CIFAR-10.
Step 1: compile the model and load and preprocess the dataset
model.compile(optimizer=tf.optimizers.Adam(0.001),
              loss=tf.losses.SparseCategoricalCrossentropy(),
              metrics=['acc'])

from tensorflow.keras.datasets import cifar10
(x_train, y_train), (x_val, y_val) = cifar10.load_data()
x_train = x_train / 255
x_val = x_val / 255
Step 2: inspect the dataset
print(x_train.shape)

plt.figure()
plt.imshow(x_train[0])
plt.show()
Step 3: fit the data and train the network
model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10,
          batch_size=64)
The training output is as follows:
Epoch 1/10
782/782 [==============================] - 1038s 1s/step - loss: 1.7900 - acc: 0.4212 - val_loss: 2.2747 - val_acc: 0.3473
Epoch 2/10
782/782 [==============================] - 1084s 1s/step - loss: 1.4167 - acc: 0.5629 - val_loss: 1.6816 - val_acc: 0.4755
Epoch 3/10
782/782 [==============================] - 1047s 1s/step - loss: 1.2337 - acc: 0.6355 - val_loss: 1.9268 - val_acc: 0.4499
Epoch 4/10
782/782 [==============================] - 1059s 1s/step - loss: 1.1222 - acc: 0.6760 - val_loss: 1.4456 - val_acc: 0.5592
Epoch 5/10
782/782 [==============================] - 1075s 1s/step - loss: 1.0435 - acc: 0.7047 - val_loss: 1.7463 - val_acc: 0.5160
Epoch 6/10
782/782 [==============================] - 1094s 1s/step - loss: 0.9957 - acc: 0.7297 - val_loss: 1.9739 - val_acc: 0.5149
Epoch 7/10
782/782 [==============================] - 1109s 1s/step - loss: 0.9553 - acc: 0.7510 - val_loss: 1.3359 - val_acc: 0.6366
Epoch 8/10
782/782 [==============================] - 1120s 1s/step - loss: 0.9221 - acc: 0.7681 - val_loss: 1.3839 - val_acc: 0.6401
Epoch 9/10
782/782 [==============================] - 1129s 1s/step - loss: 0.8882 - acc: 0.7852 - val_loss: 1.1889 - val_acc: 0.6920
Epoch 10/10
782/782 [==============================] - 1137s 1s/step - loss: 0.8584 - acc: 0.8003 - val_loss: 1.3718 - val_acc: 0.6465
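With training done, the validation numbers can be double-checked with the standard Keras calls model.evaluate and model.predict (the image index 0 below is arbitrary):

val_loss, val_acc = model.evaluate(x_val, y_val, batch_size=64)
print(val_loss, val_acc)

# Predict the class of a single validation image
probs = model.predict(x_val[:1])
print(np.argmax(probs, axis=1), y_val[0])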
Since my computer is not great, some settings were left untuned, for example the regularization, the number of epochs, the batch_size, and so on.
If you find any mistakes, corrections are welcome.