
ResNet implementation: building every ResNet variant (ResNet-18, ResNet-34, ResNet-50, and so on) with one set of generic functions

I. Preparation

There are already plenty of online write-ups of the Residual Networks paper by Kaiming He et al. After studying it for a while, I wanted to work out a set of generic functions that cover every ResNet variant, to make experimenting easier. So in this post I focus on how to implement those generic functions.

Key point 1:

In the figure above, note that F(x) + x is computed before the ReLU activation. Knowing this, the convolution helper conv2D_BN below deliberately applies no activation.
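
To make the ordering concrete, here is a minimal sketch (my own illustration, not code from this post) of a residual branch where the addition happens before the activation:

    import tensorflow as tf
    from tensorflow import keras

    # Toy residual branch: F(x) is conv + BN with no activation;
    # the skip connection is added first, and ReLU is applied to the sum.
    x = keras.layers.Input(shape=(8, 8, 64))
    f_x = keras.layers.Conv2D(64, (3, 3), padding='same')(x)
    f_x = keras.layers.BatchNormalization()(f_x)
    out = keras.layers.add([f_x, x])            # F(x) + x first ...
    out = keras.layers.Activation('relu')(out)  # ... then ReLU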

Key point 2:

From the architecture diagram we can see that only the first block of each stage of residual blocks needs a 1x1 convolution on its shortcut, and that shortcut is drawn as a dashed line, e.g. position 3 in the figure. There is one exception worth noting: position 1 is also the first block of a stage, yet it does not need a 1x1 convolution, and its shortcut is a solid line. Likewise, the shortcuts inside a stage, such as position 2, need no 1x1 convolution and are also drawn as solid lines.
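
As a quick sanity check of this rule (my own sketch; the block counts assume the ResNet-18 layout used later), the blocks that need a 1x1 projection shortcut are exactly the first block of every stage except the very first stage:

    # Which blocks get a dashed (1x1 projection) shortcut?
    blocks_per_stage = [2, 2, 2, 2]  # ResNet-18 layout
    for stage, n_blocks in enumerate(blocks_per_stage):
        for block in range(n_blocks):
            needs_projection = (block == 0 and stage != 0)
            print(f"stage {stage}, block {block}: 1x1 shortcut = {needs_projection}")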

Key point 3:

To build a generic function that fits every ResNet type, we need a table-like structure that encodes the information above, as shown below:

    # 50-layer
    # For example, in [3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]]] the first element, 3, is the
    # number of residual blocks of this type; the second element lists the convolution
    # layers of the block as [number of filters, kernel_size] pairs.
    filter_list_resnet50 = [ [3, [[64,(1,1)],[64,(3,3)],[256,(1,1)]] ],
                             [4, [[128,(1,1)],[128,(3,3)],[512,(1,1)]] ],
                             [6, [[256,(1,1)],[256,(3,3)],[1024,(1,1)]] ],
                             [3, [[512,(1,1)],[512,(3,3)],[2048,(1,1)]] ]]
    # 18-layer
    filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                             [2, [[128,(3,3)],[128,(3,3)]] ],
                             [2, [[256,(3,3)],[256,(3,3)]] ],
                             [2, [[512,(3,3)],[512,(3,3)]] ]]
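
Following the same table format, a ResNet-34 configuration would presumably look like this (my own addition, based on the 3/4/6/3 block counts in the paper; it is not used in the experiments below):

    # 34-layer: same basic block as ResNet-18, but 3/4/6/3 blocks per stage
    filter_list_resnet34 = [ [3, [[64,(3,3)],[64,(3,3)]] ],
                             [4, [[128,(3,3)],[128,(3,3)]] ],
                             [6, [[256,(3,3)],[256,(3,3)]] ],
                             [3, [[512,(3,3)],[512,(3,3)]] ]]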

Key point 4:

Sometimes the feature map size stays the same and sometimes it is halved, so the corresponding padding may differ.

Apart from the 1x1 convolutions, which use padding='valid', every convolution uses padding='same'.
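
A quick shape check (my own sketch, assuming `from tensorflow import keras`) shows why this is safe: with a 1x1 kernel, 'valid' and 'same' padding give the same output size, so a stride-2 1x1 shortcut still halves the feature map exactly like a stride-2 3x3 convolution with 'same' padding:

    from tensorflow import keras

    inp = keras.layers.Input(shape=(8, 8, 64))
    a = keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')(inp)
    b = keras.layers.Conv2D(128, (1, 1), strides=(2, 2), padding='valid')(inp)
    print(a.shape, b.shape)  # both (None, 4, 4, 128)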

II. Implementation

Step 1: import the libraries

    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    import matplotlib.pyplot as plt

Step 2: implement convolution + batch normalization

    def conv2D_BN(x, num_filter, kernel_size, strides=(1,1), padding='same'):
        '''
        Convolution + batch normalization with no activation,
        so that F(x) + x can be computed first and ReLU applied afterwards.
        '''
        conv = keras.layers.Conv2D(filters=num_filter, kernel_size=kernel_size,
                                   strides=strides, padding=padding,
                                   kernel_regularizer=keras.regularizers.l2(0.0001))(x)
        bn = keras.layers.BatchNormalization()(conv)
        return bn
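
A quick usage check (my own sketch, assuming the imports and conv2D_BN above): the helper halves a 32x32 input with a stride-2 7x7 convolution and, as intended, applies no ReLU:

    inp = keras.layers.Input(shape=(32, 32, 3))
    out = conv2D_BN(inp, num_filter=64, kernel_size=(7, 7), strides=(2, 2))
    print(out.shape)  # (None, 16, 16, 64) -- BN output, not yet activated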

Step 3: implement the basic residual block

From the paper we know that when the output map size stays the same, the number of filters stays the same; when the size is halved, the number of filters is doubled. The former happens inside a building block, the latter between stages of building blocks.

This gives us padding='same': with strides=1 the feature map size is unchanged, and with strides=2 it is halved.
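
In other words, with padding='same' the output size is ceil(input_size / stride) regardless of the kernel size, so stride 1 keeps the map size and stride 2 halves it. A small worked check (my own sketch):

    import math
    # 'same' padding: output = ceil(input / stride)
    for size, stride in [(8, 1), (8, 2), (32, 2)]:
        print(size, stride, '->', math.ceil(size / stride))  # 8->8, 8->4, 32->16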

    def building_block(x, filters_list, is_first_layers=False):
        '''
        A basic residual block that works for every residual block type.
        is_first_layers=True means strides=2, i.e. the feature map size is halved,
        and the shortcut needs a 1x1 convolution (the dashed shortcut);
        otherwise strides=1 and the feature map size is unchanged.
        filters_list: the convolution layers of this block, as a list of
        [number of filters, kernel_size] entries.
        '''
        y = x
        strides = (1,1)
        for i in range(len(filters_list)):
            if is_first_layers and i == 0:
                strides = (2,2)
            else:
                strides = (1,1)
            y = conv2D_BN(y, filters_list[i][0], kernel_size=filters_list[i][1], strides=strides)
            # shortcut
            '''
            If is_first_layers is True and this is the last layer of the block,
            the shortcut needs a 1x1 convolution that halves the input feature map
            (strides=2). Its number of filters must match the number of filters of
            the current (last) layer, i.e. filters=filters_list[i][0].
            '''
            if is_first_layers and i == len(filters_list) - 1:
                x = conv2D_BN(x, filters_list[i][0], kernel_size=(1,1),
                              strides=(2,2), padding='valid')
                break
            # On the last layer of the block we do not activate yet: the block input
            # and output are added first. Between the inner layers of the block we
            # can activate directly.
            if i == len(filters_list) - 1:
                break
            y = keras.layers.Activation('relu')(y)
        f = keras.layers.add([x, y])
        return keras.layers.Activation('relu')(f)
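
An illustrative shape check for building_block (my own sketch, assuming the definitions above), using the basic two-layer 3x3 block of ResNet-18:

    inp = keras.layers.Input(shape=(8, 8, 64))
    same = building_block(inp, [[64,(3,3)],[64,(3,3)]], is_first_layers=False)
    down = building_block(inp, [[128,(3,3)],[128,(3,3)]], is_first_layers=True)
    print(same.shape)  # (None, 8, 8, 64)  -- size and filters unchanged
    print(down.shape)  # (None, 4, 4, 128) -- size halved, filters doubled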

Step 4: implement the main body of the residual network, i.e. the part that differs between variants

    def residual_main_network(x, filter_list_resnet):
        for i in range(len(filter_list_resnet)):
            for j in range(filter_list_resnet[i][0]):
                # The first block of a stage (j == 0) uses a projection shortcut,
                # except for the very first stage (i == 0), which does not need
                # a 1x1 shortcut convolution.
                if j == 0 and i != 0:
                    is_first_layers = True
                else:
                    is_first_layers = False
                x = building_block(x, filters_list=filter_list_resnet[i][1],
                                   is_first_layers=is_first_layers)
        return x
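
A quick check of the whole stack (my own sketch, assuming filter_list_resnet18 from the table above): starting from an 8x8x64 feature map, i.e. the size of pool1 for a 32x32 input, the four stages end at 1x1x512:

    inp = keras.layers.Input(shape=(8, 8, 64))
    out = residual_main_network(inp, filter_list_resnet18)
    print(out.shape)  # (None, 1, 1, 512)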

Step 5: implement the full residual network

    def resnet(nclass, input_shape, filter_list_resnet):  # nclass: number of output classes; input_shape: shape of the input
        input_ = keras.layers.Input(shape=input_shape)
        conv1 = conv2D_BN(input_, 64, kernel_size=(7,7), strides=(2,2))
        conv1 = keras.layers.Activation('relu')(conv1)
        pool1 = keras.layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1)
        conv2 = residual_main_network(pool1, filter_list_resnet)
        pool2 = keras.layers.GlobalAvgPool2D()(conv2)
        output_ = keras.layers.Dense(nclass, activation='softmax')(pool2)
        model = keras.Model(inputs=input_, outputs=output_)
        return model

III. Example: building ResNet-18 with the functions above

    filter_list_resnet18 = [ [2, [[64,(3,3)],[64,(3,3)]] ],
                             [2, [[128,(3,3)],[128,(3,3)]] ],
                             [2, [[256,(3,3)],[256,(3,3)]] ],
                             [2, [[512,(3,3)],[512,(3,3)]] ]]
    model = resnet(10, (32,32,3), filter_list_resnet18)
    model.summary()

Output:

    Model: "model_3"
    _________________________________________________________________
    Layer (type) Output Shape Param #
    =================================================================
    input_5 (InputLayer) [(None, 32, 32, 3)] 0
    _________________________________________________________________
    conv2d_66 (Conv2D) (None, 16, 16, 64) 9472
    _________________________________________________________________
    batch_normalization_66 (Batc (None, 16, 16, 64) 256
    _________________________________________________________________
    activation_55 (Activation) (None, 16, 16, 64) 0
    _________________________________________________________________
    max_pooling2d_4 (MaxPooling2 (None, 8, 8, 64) 0
    _________________________________________________________________
    conv2d_67 (Conv2D) (None, 8, 8, 64) 36928
    _________________________________________________________________
    batch_normalization_67 (Batc (None, 8, 8, 64) 256
    _________________________________________________________________
    activation_56 (Activation) (None, 8, 8, 64) 0
    _________________________________________________________________
    conv2d_68 (Conv2D) (None, 8, 8, 64) 36928
    _________________________________________________________________
    batch_normalization_68 (Batc (None, 8, 8, 64) 256
    _________________________________________________________________
    activation_57 (Activation) (None, 8, 8, 64) 0
    _________________________________________________________________
    conv2d_69 (Conv2D) (None, 8, 8, 64) 36928
    _________________________________________________________________
    batch_normalization_69 (Batc (None, 8, 8, 64) 256
    _________________________________________________________________
    activation_58 (Activation) (None, 8, 8, 64) 0
    _________________________________________________________________
    conv2d_70 (Conv2D) (None, 8, 8, 64) 36928
    _________________________________________________________________
    batch_normalization_70 (Batc (None, 8, 8, 64) 256
    _________________________________________________________________
    activation_59 (Activation) (None, 8, 8, 64) 0
    _________________________________________________________________
    conv2d_71 (Conv2D) (None, 4, 4, 128) 73856
    _________________________________________________________________
    batch_normalization_71 (Batc (None, 4, 4, 128) 512
    _________________________________________________________________
    activation_60 (Activation) (None, 4, 4, 128) 0
    _________________________________________________________________
    conv2d_72 (Conv2D) (None, 4, 4, 128) 147584
    _________________________________________________________________
    batch_normalization_72 (Batc (None, 4, 4, 128) 512
    _________________________________________________________________
    activation_61 (Activation) (None, 4, 4, 128) 0
    _________________________________________________________________
    conv2d_74 (Conv2D) (None, 4, 4, 128) 147584
    _________________________________________________________________
    batch_normalization_74 (Batc (None, 4, 4, 128) 512
    _________________________________________________________________
    activation_62 (Activation) (None, 4, 4, 128) 0
    _________________________________________________________________
    conv2d_75 (Conv2D) (None, 4, 4, 128) 147584
    _________________________________________________________________
    batch_normalization_75 (Batc (None, 4, 4, 128) 512
    _________________________________________________________________
    activation_63 (Activation) (None, 4, 4, 128) 0
    _________________________________________________________________
    conv2d_76 (Conv2D) (None, 2, 2, 256) 295168
    _________________________________________________________________
    batch_normalization_76 (Batc (None, 2, 2, 256) 1024
    _________________________________________________________________
    activation_64 (Activation) (None, 2, 2, 256) 0
    _________________________________________________________________
    conv2d_77 (Conv2D) (None, 2, 2, 256) 590080
    _________________________________________________________________
    batch_normalization_77 (Batc (None, 2, 2, 256) 1024
    _________________________________________________________________
    activation_65 (Activation) (None, 2, 2, 256) 0
    _________________________________________________________________
    conv2d_79 (Conv2D) (None, 2, 2, 256) 590080
    _________________________________________________________________
    batch_normalization_79 (Batc (None, 2, 2, 256) 1024
    _________________________________________________________________
    activation_66 (Activation) (None, 2, 2, 256) 0
    _________________________________________________________________
    conv2d_80 (Conv2D) (None, 2, 2, 256) 590080
    _________________________________________________________________
    batch_normalization_80 (Batc (None, 2, 2, 256) 1024
    _________________________________________________________________
    activation_67 (Activation) (None, 2, 2, 256) 0
    _________________________________________________________________
    conv2d_81 (Conv2D) (None, 1, 1, 512) 1180160
    _________________________________________________________________
    batch_normalization_81 (Batc (None, 1, 1, 512) 2048
    _________________________________________________________________
    activation_68 (Activation) (None, 1, 1, 512) 0
    _________________________________________________________________
    conv2d_82 (Conv2D) (None, 1, 1, 512) 2359808
    _________________________________________________________________
    batch_normalization_82 (Batc (None, 1, 1, 512) 2048
    _________________________________________________________________
    activation_69 (Activation) (None, 1, 1, 512) 0
    _________________________________________________________________
    conv2d_84 (Conv2D) (None, 1, 1, 512) 2359808
    _________________________________________________________________
    batch_normalization_84 (Batc (None, 1, 1, 512) 2048
    _________________________________________________________________
    activation_70 (Activation) (None, 1, 1, 512) 0
    _________________________________________________________________
    conv2d_85 (Conv2D) (None, 1, 1, 512) 2359808
    _________________________________________________________________
    batch_normalization_85 (Batc (None, 1, 1, 512) 2048
    _________________________________________________________________
    activation_71 (Activation) (None, 1, 1, 512) 0
    _________________________________________________________________
    global_average_pooling2d_3 ( (None, 512) 0
    _________________________________________________________________
    dense_3 (Dense) (None, 10) 5130
    =================================================================
    Total params: 11,019,530
    Trainable params: 11,011,722
    Non-trainable params: 7,808

Next, we train this model on CIFAR-10.

Step 1: compile the model, then load and preprocess the dataset

    model.compile(optimizer=tf.optimizers.Adam(0.001),
                  loss=tf.losses.SparseCategoricalCrossentropy(),
                  metrics=['acc'])
    from keras.datasets import cifar10
    (x_train, y_train), (x_val, y_val) = cifar10.load_data()
    x_train = x_train / 255
    x_val = x_val / 255

Step 2: inspect the dataset

    print(x_train.shape)
    plt.figure()
    plt.imshow(x_train[0])
    plt.show()

Step 3: fit the data and train the network

    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10,
              batch_size=64)
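
After training, one can run an explicit evaluation pass over the validation split (my own addition; it simply re-reports the validation loss and accuracy):

    val_loss, val_acc = model.evaluate(x_val, y_val, batch_size=64)
    print(val_loss, val_acc)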

The training output is as follows:

    Epoch 1/10
    782/782 [==============================] - 1038s 1s/step - loss: 1.7900 - acc: 0.4212 - val_loss: 2.2747 - val_acc: 0.3473
    Epoch 2/10
    782/782 [==============================] - 1084s 1s/step - loss: 1.4167 - acc: 0.5629 - val_loss: 1.6816 - val_acc: 0.4755
    Epoch 3/10
    782/782 [==============================] - 1047s 1s/step - loss: 1.2337 - acc: 0.6355 - val_loss: 1.9268 - val_acc: 0.4499
    Epoch 4/10
    782/782 [==============================] - 1059s 1s/step - loss: 1.1222 - acc: 0.6760 - val_loss: 1.4456 - val_acc: 0.5592
    Epoch 5/10
    782/782 [==============================] - 1075s 1s/step - loss: 1.0435 - acc: 0.7047 - val_loss: 1.7463 - val_acc: 0.5160
    Epoch 6/10
    782/782 [==============================] - 1094s 1s/step - loss: 0.9957 - acc: 0.7297 - val_loss: 1.9739 - val_acc: 0.5149
    Epoch 7/10
    782/782 [==============================] - 1109s 1s/step - loss: 0.9553 - acc: 0.7510 - val_loss: 1.3359 - val_acc: 0.6366
    Epoch 8/10
    782/782 [==============================] - 1120s 1s/step - loss: 0.9221 - acc: 0.7681 - val_loss: 1.3839 - val_acc: 0.6401
    Epoch 9/10
    782/782 [==============================] - 1129s 1s/step - loss: 0.8882 - acc: 0.7852 - val_loss: 1.1889 - val_acc: 0.6920
    Epoch 10/10
    782/782 [==============================] - 1137s 1s/step - loss: 0.8584 - acc: 0.8003 - val_loss: 1.3718 - val_acc: 0.6465

Since my machine is not very powerful, I did not tune some of the hyperparameters, such as the regularization strength, the number of epochs, the batch size, and so on.

If you spot any mistakes, corrections are welcome.
