This post tries a VGG-style convolutional network for classifying images in the CIFAR-10 dataset.
1 Import the required modules
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt
2 Load the CIFAR-10 dataset
# load CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
# train_images: 50000*32*32*3, train_labels: 50000*1, test_images: 10000*32*32*3, test_labels: 10000*1
# normalize pixel values to the range [0, 1] (uint8 -> float)
train_input = train_images/255.0
test_input = test_images/255.0
train_output = train_labels
test_output = test_labels
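Before building the network, it can be worth confirming that the arrays have the expected shapes and value range. The quick check below is an optional addition (not part of the original code) and is not required for the later steps:

# optional sanity check: shapes and value range after normalization
print(train_input.shape, train_output.shape)   # (50000, 32, 32, 3) (50000, 1)
print(test_input.shape, test_output.shape)     # (10000, 32, 32, 3) (10000, 1)
print(train_input.min(), train_input.max())    # 0.0 1.0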
3 Build the neural network
First, define a function that builds the model:
def build_model():
    model = models.Sequential()
    # 1st conv layer, input shape (32,32,3)
    model.add(layers.Conv2D(64, (3,3), padding='same', input_shape=(32,32,3)))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.3))
    # 2nd conv layer, input shape (32,32,64)
    model.add(layers.Conv2D(64, (3,3), padding='same'))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same'))
    # 3rd conv layer, input shape (16,16,64)
    model.add(layers.Conv2D(128, (3,3), padding='same'))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.4))
    # 4th conv layer, input shape (16,16,128)
    model.add(layers.Conv2D(128, (3,3), padding='same'))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2,2)))
    # 5th conv layer, input shape (8,8,128)
    model.add(layers.Conv2D(256, (3,3), padding='same'))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.4))
    # 6th layer: fully connected, input shape (8,8,256) flattened to 16384
    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    model.add(layers.Activation('relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.5))
    # 7th layer: fully connected output, 512 -> 10 classes
    model.add(layers.Dense(10))
    model.add(layers.Activation('softmax'))
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    return model
The network stacks 5 convolutional layers followed by 2 fully connected layers. Now call the function to build it:
# build model
network = build_model()
# show network summary
network.summary()
The summary is shown below:
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 32, 32, 64) 1792 _________________________________________________________________ activation (Activation) (None, 32, 32, 64) 0 _________________________________________________________________ batch_normalization (BatchNo (None, 32, 32, 64) 256 _________________________________________________________________ dropout (Dropout) (None, 32, 32, 64) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 32, 32, 64) 36928 _________________________________________________________________ activation_1 (Activation) (None, 32, 32, 64) 0 _________________________________________________________________ batch_normalization_1 (Batch (None, 32, 32, 64) 256 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 16, 16, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 16, 16, 128) 73856 _________________________________________________________________ activation_2 (Activation) (None, 16, 16, 128) 0 _________________________________________________________________ batch_normalization_2 (Batch (None, 16, 16, 128) 512 _________________________________________________________________ dropout_1 (Dropout) (None, 16, 16, 128) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 16, 16, 128) 147584 _________________________________________________________________ activation_3 (Activation) (None, 16, 16, 128) 0 _________________________________________________________________ batch_normalization_3 (Batch (None, 16, 16, 128) 512 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 8, 8, 128) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 8, 8, 256) 295168 _________________________________________________________________ activation_4 (Activation) (None, 8, 8, 256) 0 _________________________________________________________________ batch_normalization_4 (Batch (None, 8, 8, 256) 1024 _________________________________________________________________ dropout_2 (Dropout) (None, 8, 8, 256) 0 _________________________________________________________________ flatten (Flatten) (None, 16384) 0 _________________________________________________________________ dense (Dense) (None, 512) 8389120 _________________________________________________________________ activation_5 (Activation) (None, 512) 0 _________________________________________________________________ batch_normalization_5 (Batch (None, 512) 2048 _________________________________________________________________ dropout_3 (Dropout) (None, 512) 0 _________________________________________________________________ dense_1 (Dense) (None, 10) 5130 _________________________________________________________________ activation_6 (Activation) (None, 10) 0 ================================================================= Total params: 8,954,186 Trainable params: 8,951,882 Non-trainable params: 2,304 _________________________________________________________________
4 Train the model
Call fit() to train the network:
# train model
history = network.fit(train_input, train_output, epochs=30, batch_size=256, validation_split=0.1)
Training runs for 30 epochs with a batch size of 256; the training log is shown below:
Epoch  1/30  176/176 - 15s 62ms/step - loss: 1.7285 - sparse_categorical_accuracy: 0.4502 - val_loss: 4.3244 - val_sparse_categorical_accuracy: 0.1566
Epoch  2/30  176/176 - 9s 53ms/step - loss: 1.0967 - sparse_categorical_accuracy: 0.6184 - val_loss: 4.2188 - val_sparse_categorical_accuracy: 0.2258
Epoch  3/30  176/176 - 9s 53ms/step - loss: 0.8392 - sparse_categorical_accuracy: 0.7048 - val_loss: 1.5962 - val_sparse_categorical_accuracy: 0.5262
Epoch  4/30  176/176 - 9s 53ms/step - loss: 0.6949 - sparse_categorical_accuracy: 0.7549 - val_loss: 0.7939 - val_sparse_categorical_accuracy: 0.7430
Epoch  5/30  176/176 - 9s 53ms/step - loss: 0.5994 - sparse_categorical_accuracy: 0.7880 - val_loss: 0.9743 - val_sparse_categorical_accuracy: 0.7168
Epoch  6/30  176/176 - 9s 54ms/step - loss: 0.5187 - sparse_categorical_accuracy: 0.8173 - val_loss: 0.8175 - val_sparse_categorical_accuracy: 0.7520
Epoch  7/30  176/176 - 9s 54ms/step - loss: 0.4544 - sparse_categorical_accuracy: 0.8398 - val_loss: 0.8302 - val_sparse_categorical_accuracy: 0.7544
Epoch  8/30  176/176 - 9s 54ms/step - loss: 0.3901 - sparse_categorical_accuracy: 0.8629 - val_loss: 0.7184 - val_sparse_categorical_accuracy: 0.7934
Epoch  9/30  176/176 - 9s 54ms/step - loss: 0.3390 - sparse_categorical_accuracy: 0.8794 - val_loss: 0.8141 - val_sparse_categorical_accuracy: 0.7962
Epoch 10/30  176/176 - 10s 54ms/step - loss: 0.2886 - sparse_categorical_accuracy: 0.8964 - val_loss: 0.9829 - val_sparse_categorical_accuracy: 0.7804
Epoch 11/30  176/176 - 10s 54ms/step - loss: 0.2630 - sparse_categorical_accuracy: 0.9075 - val_loss: 0.7088 - val_sparse_categorical_accuracy: 0.8034
Epoch 12/30  176/176 - 10s 54ms/step - loss: 0.2362 - sparse_categorical_accuracy: 0.9164 - val_loss: 0.5813 - val_sparse_categorical_accuracy: 0.8336
Epoch 13/30  176/176 - 10s 54ms/step - loss: 0.2086 - sparse_categorical_accuracy: 0.9269 - val_loss: 0.7702 - val_sparse_categorical_accuracy: 0.8014
Epoch 14/30  176/176 - 10s 54ms/step - loss: 0.1860 - sparse_categorical_accuracy: 0.9345 - val_loss: 0.7444 - val_sparse_categorical_accuracy: 0.8254
Epoch 15/30  176/176 - 10s 54ms/step - loss: 0.1748 - sparse_categorical_accuracy: 0.9398 - val_loss: 0.7130 - val_sparse_categorical_accuracy: 0.8184
Epoch 16/30  176/176 - 10s 54ms/step - loss: 0.1582 - sparse_categorical_accuracy: 0.9443 - val_loss: 0.7712 - val_sparse_categorical_accuracy: 0.8226
Epoch 17/30  176/176 - 10s 55ms/step - loss: 0.1459 - sparse_categorical_accuracy: 0.9488 - val_loss: 0.8808 - val_sparse_categorical_accuracy: 0.8086
Epoch 18/30  176/176 - 10s 54ms/step - loss: 0.1329 - sparse_categorical_accuracy: 0.9530 - val_loss: 0.7062 - val_sparse_categorical_accuracy: 0.8340
Epoch 19/30  176/176 - 10s 54ms/step - loss: 0.1323 - sparse_categorical_accuracy: 0.9538 - val_loss: 0.6216 - val_sparse_categorical_accuracy: 0.8380
Epoch 20/30  176/176 - 10s 55ms/step - loss: 0.1243 - sparse_categorical_accuracy: 0.9575 - val_loss: 0.6749 - val_sparse_categorical_accuracy: 0.8334
Epoch 21/30  176/176 - 10s 54ms/step - loss: 0.1206 - sparse_categorical_accuracy: 0.9586 - val_loss: 0.7408 - val_sparse_categorical_accuracy: 0.8268
Epoch 22/30  176/176 - 10s 54ms/step - loss: 0.1105 - sparse_categorical_accuracy: 0.9615 - val_loss: 0.7999 - val_sparse_categorical_accuracy: 0.8314
Epoch 23/30  176/176 - 10s 55ms/step - loss: 0.1064 - sparse_categorical_accuracy: 0.9633 - val_loss: 0.6867 - val_sparse_categorical_accuracy: 0.8396
Epoch 24/30  176/176 - 10s 54ms/step - loss: 0.0974 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.6695 - val_sparse_categorical_accuracy: 0.8422
Epoch 25/30  176/176 - 10s 54ms/step - loss: 0.0908 - sparse_categorical_accuracy: 0.9687 - val_loss: 0.7222 - val_sparse_categorical_accuracy: 0.8306
Epoch 26/30  176/176 - 10s 54ms/step - loss: 0.0907 - sparse_categorical_accuracy: 0.9689 - val_loss: 0.6841 - val_sparse_categorical_accuracy: 0.8384
Epoch 27/30  176/176 - 10s 55ms/step - loss: 0.0866 - sparse_categorical_accuracy: 0.9696 - val_loss: 0.8356 - val_sparse_categorical_accuracy: 0.8286
Epoch 28/30  176/176 - 10s 55ms/step - loss: 0.0898 - sparse_categorical_accuracy: 0.9690 - val_loss: 0.6899 - val_sparse_categorical_accuracy: 0.8392
Epoch 29/30  176/176 - 10s 55ms/step - loss: 0.0867 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.7572 - val_sparse_categorical_accuracy: 0.8338
Epoch 30/30  176/176 - 10s 55ms/step - loss: 0.0796 - sparse_categorical_accuracy: 0.9728 - val_loss: 0.7699 - val_sparse_categorical_accuracy: 0.8336
After training, the model reaches a training accuracy of 0.9728 and a validation accuracy of 0.8336 (on the 10% split held out by validation_split). The training history can be plotted as follows:
# plot training history
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
plt.figure(figsize=(10,3))
plt.subplot(1,2,1)
plt.plot(loss, color='blue', label='train')
plt.plot(val_loss, color='red', label='validation')
plt.ylabel('loss')
plt.legend()
plt.subplot(1,2,2)
plt.plot(acc, color='blue', label='train')
plt.plot(val_acc, color='red', label='validation')
plt.ylabel('accuracy')
plt.legend()
plt.show()
The resulting plot is shown below:
Both the training and validation curves improve as training progresses, with the validation accuracy levelling off around 0.83. This indicates better generalization than the previous model.
5 Evaluate the trained model
# evaluate model
network.evaluate(test_input, test_output, verbose=2)
This prints the loss and accuracy on the test set:
313/313 - 1s - loss: 0.7931 - sparse_categorical_accuracy: 0.8315
[0.7930535078048706, 0.8314999938011169]
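The two numbers are the test loss and the test accuracy, in the order defined in compile(). If preferred, they can be unpacked directly, as in the small sketch below (an addition, not part of the original code):

# evaluate() returns [loss, metric]; unpack them for a readable report
test_loss, test_acc = network.evaluate(test_input, test_output, verbose=2)
print('test loss: %.4f, test accuracy: %.4f' % (test_loss, test_acc))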
The result is close to the validation metrics seen during training. Next, plot the predictions for the first 100 test images:
# predict on the test data
predict_output = network.predict(test_input)
# rows and columns of subplots
m = 10
n = 10
num = m*n
# category names
labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']
# figure size
plt.figure(figsize=(15,15))
# plot the first 100 test images together with the predicted labels
for i in range(num):
    plt.subplot(m, n, i+1)
    type_index = np.argmax(predict_output[i])
    label = labels[type_index]
    # red label: prediction differs from the true class
    clr = 'black' if type_index == test_labels[i, 0] else 'red'
    plt.imshow(test_images[i])
    #plt.axis('off')
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(label, color=clr)
plt.show()
The final figure is shown below:
Labels in red mark incorrect predictions; their proportion is roughly consistent with the accuracy reported on the test set. Compared with the earlier models, this type of network shows better recognition performance and generalization.
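To confirm the visual impression numerically, the fraction of correct predictions among those 100 images can be computed from predict_output. This small check is an addition and simply reuses the variables defined above:

# fraction of correct predictions among the first 100 test images;
# this should be roughly in line with the ~0.83 test accuracy
pred_classes = np.argmax(predict_output[:num], axis=1)
first100_acc = np.mean(pred_classes == test_labels[:num, 0])
print('accuracy on the first %d test images: %.2f' % (num, first100_acc))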
Reference: https://blog.csdn.net/Mind_programmonkey/article/details/121049217