当前位置:   article > 正文

卷积神经网络——vgg16网络及其python实现_python vgg16

python vgg16

1、介绍     

        VGG-16网络包括13个卷积层和3个全连接层,网络结构较LeNet-5等网络变得十分复杂,但同时也有不错的效果。VGG16有强大的拟合能力在当时取得了非常的效果,但同时VGG也有部分不足:
1、巨大参数量导致训练时间过长,调参难度较大;
2、模型所需内存容量大,VGG的权值文件很大,用到实际应用会比较困难。

2、结构原理 

这是经典的vgg网络,输入图片大小为224*224。

下面这为官方给出的几种VGG结构图。

 

 现在多用的为D模型。

简单介绍下过程,输入224*224大小的图片,然后用两次64个3*3的卷积核进行全采集,也就是补零采集,保证特征不丢失,得到64*224*224的特征;池化层得到64*112*112;再利用128个3*3的卷积核进行特征采集两次,得到特征112*112*128;池化得到56*56*128大小特征.........反复这样操作,最后卷积完得到7*7*512的特征,然后利用全连接层进行展开,最后得到1000个特征,随后进行概率分类操作。

3、python实现

        选用的数据集为fashion数据集,具体请另外了解。数据可直接在库中导入,本文用class网络编写神经网络程序。

  1. class VGG16(Model):
  2. def __init__(self):
  3. super(VGG16, self).__init__()
  4. self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same')
  5. self.b1 = BatchNormalization()
  6. self.a1 = Activation('relu')
  7. self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', )
  8. self.b2 = BatchNormalization()
  9. self.a2 = Activation('relu')
  10. self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  11. self.d1 = Dropout(0.2)
  12. self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
  13. self.b3 = BatchNormalization()
  14. self.a3 = Activation('relu')
  15. self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
  16. self.b4 = BatchNormalization()
  17. self.a4 = Activation('relu')
  18. self.p2 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  19. self.d2 = Dropout(0.2)
  20. self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
  21. self.b5 = BatchNormalization()
  22. self.a5 = Activation('relu')
  23. self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
  24. self.b6 = BatchNormalization()
  25. self.a6 = Activation('relu')
  26. self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
  27. self.b7 = BatchNormalization()
  28. self.a7 = Activation('relu')
  29. self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  30. self.d3 = Dropout(0.2)
  31. self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  32. self.b8 = BatchNormalization()
  33. self.a8 = Activation('relu')
  34. self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  35. self.b9 = BatchNormalization()
  36. self.a9 = Activation('relu')
  37. self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  38. self.b10 = BatchNormalization()
  39. self.a10 = Activation('relu')
  40. self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  41. self.d4 = Dropout(0.2)
  42. self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  43. self.b11 = BatchNormalization()
  44. self.a11 = Activation('relu')
  45. self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  46. self.b12 = BatchNormalization()
  47. self.a12 = Activation('relu')
  48. self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
  49. self.b13 = BatchNormalization()
  50. self.a13 = Activation('relu')
  51. self.p5 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
  52. self.d5 = Dropout(0.2)
  53. self.flatten = Flatten()
  54. self.f1 = Dense(512, activation='relu')
  55. self.d6 = Dropout(0.2)
  56. self.f2 = Dense(512, activation='relu')
  57. self.d7 = Dropout(0.2)
  58. self.f3 = Dense(10, activation='softmax')
  59. def call(self, x):
  60. x = self.c1(x)
  61. x = self.b1(x)
  62. x = self.a1(x)
  63. x = self.c2(x)
  64. x = self.b2(x)
  65. x = self.a2(x)
  66. x = self.p1(x)
  67. x = self.d1(x)
  68. x = self.c3(x)
  69. x = self.b3(x)
  70. x = self.a3(x)
  71. x = self.c4(x)
  72. x = self.b4(x)
  73. x = self.a4(x)
  74. x = self.p2(x)
  75. x = self.d2(x)
  76. x = self.c5(x)
  77. x = self.b5(x)
  78. x = self.a5(x)
  79. x = self.c6(x)
  80. x = self.b6(x)
  81. x = self.a6(x)
  82. x = self.c7(x)
  83. x = self.b7(x)
  84. x = self.a7(x)
  85. x = self.p3(x)
  86. x = self.d3(x)
  87. x = self.c8(x)
  88. x = self.b8(x)
  89. x = self.a8(x)
  90. x = self.c9(x)
  91. x = self.b9(x)
  92. x = self.a9(x)
  93. x = self.c10(x)
  94. x = self.b10(x)
  95. x = self.a10(x)
  96. x = self.p4(x)
  97. x = self.d4(x)
  98. x = self.c11(x)
  99. x = self.b11(x)
  100. x = self.a11(x)
  101. x = self.c12(x)
  102. x = self.b12(x)
  103. x = self.a12(x)
  104. x = self.c13(x)
  105. x = self.b13(x)
  106. x = self.a13(x)
  107. x = self.p5(x)
  108. x = self.d5(x)
  109. x = self.flatten(x)
  110. x = self.f1(x)
  111. x = self.d6(x)
  112. x = self.f2(x)
  113. x = self.d7(x)
  114. y = self.f3(x)
  115. return y
  116. model = VGG16()

读取数据

  1. import tensorflow as tf
  2. import os
  3. import numpy as np
  4. from matplotlib import pyplot as plt
  5. from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
  6. from tensorflow.keras import Model
  7. np.set_printoptions(threshold=np.inf)
  8. fashion = tf.keras.datasets.fashion_mnist
  9. (x_train, y_train), (x_test, y_test) = fashion.load_data()
  10. x_train, x_test = x_train / 255.0, x_test / 255.0
  11. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
  12. x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

迭代训练

  1. model.compile(optimizer='adam',
  2. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
  3. metrics=['sparse_categorical_accuracy'])
  4. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
  5. save_weights_only=True,
  6. save_best_only=True)
  7. history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1,
  8. callbacks=[cp_callback])

绘制结果图

  1. acc = history.history['sparse_categorical_accuracy']
  2. val_acc = history.history['val_sparse_categorical_accuracy']
  3. loss = history.history['loss']
  4. val_loss = history.history['val_loss']
  5. plt.subplot(1, 2, 1)
  6. plt.plot(acc, label='Training Accuracy')
  7. plt.plot(val_acc, label='Validation Accuracy')
  8. plt.title('Training and Validation Accuracy')
  9. plt.legend()
  10. plt.subplot(1, 2, 2)
  11. plt.plot(loss, label='Training Loss')
  12. plt.plot(val_loss, label='Validation Loss')
  13. plt.title('Training and Validation Loss')
  14. plt.legend()
  15. plt.show()

 

 虽然不是很稳定,但总的来说准确率还可以。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/496865
推荐阅读
相关标签
  

闽ICP备14008679号