Training result: epoch = 10, acc = 89.96%
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 14 2020
@author: jiollos
"""
# import packages
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
# print full arrays without truncation
np.set_printoptions(threshold=np.inf)
# load the Fashion-MNIST dataset
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
# normalize pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
print("x_train.shape", x_train.shape)
# add a channel dimension so the data matches the network's expected input shape
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
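# Keras Conv2D expects NHWC tensors: (batch, height, width, channels); grayscale images have 1 channel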
# print("x_train.shape", x_train.shape)
# define the 8-layer AlexNet variant as a subclassed Model
class AlexNet8(Model):
    def __init__(self):
        super(AlexNet8, self).__init__()
        # build convolution / batch-norm / activation / pooling blocks in turn; 8 weight layers in total
        # 96 filters, 3x3 kernel
        self.c1 = Conv2D(filters=96, kernel_size=(3, 3))
        self.b1 = BatchNormalization()
        # ReLU activation
        self.a1 = Activation('relu')
        # 3x3 max pooling, stride 2
        self.p1 = MaxPool2D(pool_size=(3, 3), strides=2)
        self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
        self.b2 = BatchNormalization()
        self.a2 = Activation('relu')
        self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)
        # 384 filters, 3x3 kernel, zero ('same') padding, ReLU activation
        self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.c4 = Conv2D(filters=384, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same',
                         activation='relu')
        self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)
        # flatten the feature maps into a vector
        self.flatten = Flatten()
        # fully connected layer with 2048 units, ReLU activation
        self.f1 = Dense(2048, activation='relu')
        # drop 50% of the units during training
        self.d1 = Dropout(0.5)
        self.f2 = Dense(2048, activation='relu')
        self.d2 = Dropout(0.5)
        # output layer: 10 units with softmax activation
        self.f3 = Dense(10, activation='softmax')
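        # Feature-map shapes for a 28x28x1 input ('valid' padding unless noted):
        # c1 -> 26x26x96, p1 -> 12x12x96, c2 -> 10x10x256, p2 -> 4x4x256,
        # c3/c4 -> 4x4x384 ('same'), c5 -> 4x4x256 ('same'), p3 -> 1x1x256, flatten -> 256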
    # forward pass
    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.c2(x)
        x = self.b2(x)
        x = self.a2(x)
        x = self.p2(x)
        x = self.c3(x)
        x = self.c4(x)
        x = self.c5(x)
        x = self.p3(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.d1(x)
        x = self.f2(x)
        x = self.d2(x)
        y = self.f3(x)
        return y
# instantiate the model
model = AlexNet8()
# Adam optimizer; sparse categorical cross-entropy because the labels are integer class
# indices (not one-hot) and the network already applies softmax, hence from_logits=False
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
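# (With one-hot labels, tf.keras.losses.CategoricalCrossentropy would be used instead.)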
# set up checkpointing to save the trained model
checkpoint_save_path = "./checkpoint/AlexNet8.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
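# save_best_only monitors the default metric, val_loss: a checkpoint is written only when validation loss improves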
# train: batch_size=32, 10 epochs, validate after every epoch
history = model.fit(x_train, y_train, batch_size=32, epochs=10, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
# print the network structure and parameter counts
model.summary()
# dump the trained weights to a text file for inspection
# print(model.trainable_variables)
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
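# Note: this text dump is only for inspection; load_weights reads the binary checkpoint saved above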
# plot accuracy and loss curves for the training and validation sets
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
#plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()
#plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
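To reuse the saved checkpoint later without retraining, the weights can be loaded back into a fresh model. A minimal sketch, assuming the AlexNet8 class, checkpoint_save_path, and the normalized x_test/y_test defined above:
# rebuild the model and create its variables with one dummy forward pass
restored = AlexNet8()
restored(tf.zeros((1, 28, 28, 1)))
# load the weights written by the ModelCheckpoint callback
restored.load_weights(checkpoint_save_path)
# predict on a few test images; each output row is a softmax distribution over the 10 classes
probs = restored.predict(x_test[:5])
print("predicted:", np.argmax(probs, axis=1))
print("actual:   ", y_test[:5])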