当前位置:   article > 正文

基于tensoflow2.x训练mnist数据集神经网络模型优化_tensorflow 2.10 mnist程序

tensorflow 2.10 mnist程序

基于tensoflow2.x训练mnist数据集神经网络模型优化

基于的版本

本次训练的程序是基于python3.8和tensorflow2.0来训练mnist数据集的,前几天有专门写过这方面博客,近几天进行调试时,觉得有必要再进行优化一下,包括识别率和模型储存方面的。具体优化有以下几个方面:

1.对模型的超参数进行了调整
2.网络结构做了部分调整
3.对训练好的模型进行保存和加载

1)超参数调整为:

#设置超参数
learn_rate = 0.005
batch_size = 256
epoch = 1500 
  • 1
  • 2
  • 3

超参数调整过程中成功率趋势大概是这样的,以10次迭代为标准,当然迭代次数变多成功率会变大,但是也会出现过拟合现象,具体参数调整可以自己尝试
在这里插入图片描述
2)网络结构方面,现在的网络结构中加入了Dropout层,解决了过拟合问题:

#3.设计神经网络模型
model=tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=train_x.shape[1:]))
model.add(tf.keras.layers.Dense(300,activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(100,activation='relu'))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
model.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3)模型的保存和加载
使用model.save() 和model.load() 方法可以实现对训练好的模型进行保存和加载,保存的模型中包括神经网络结构,权重系数等。具体实现为:

#保存模型
model.save('保存的路径')
  • 1
#加载模型
model  = tf.keras.models.load_model('模型文件所在目录')
  • 1

总的详细代码如下:
1)训练代码并保存模型代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

#设置超参数
learn_rate = 0.005
batch_size = 256
epoch = 1500 

#1.读取数据
mnist = tf.keras.datasets.mnist
(train_x,train_y),(test_x,test_y) = mnist.load_data()

#2.数据重组并归一化
train_x = train_x.reshape(60000,28*28)/255.0
test_x = test_x.reshape(10000,28*28)/255.0
#标签编码,变成二进制形式
train_y = tf.one_hot(train_y,depth=10)
test_y = tf.one_hot(test_y,depth=10)

#3.设计神经网络模型
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=train_x.shape[1:]))
model.add(tf.keras.layers.Dense(300,activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(100,activation='relu'))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
model.summary()

#4.模型超参数设定
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.005),    
    loss=tf.keras.losses.categorical_crossentropy,    
    metrics='accuracy')
    
#5.模型训练
model.fit(
    x = train_x,    
    y = train_y,    
    batch_size = batch_size,    
    epochs = epoch,    
    validation_data=(test_x,test_y)  #训练集预测精度
    )
    
#保存模型
model.save('model/my_model.h5')

#6.模型验证评估
n = 10
predict = model.predict(test_x[:10],n)
fig = plt.figure(figsize=(10,2))
for i in range(n):
    plt.subplot(1,10,i+1)        
    p = test_x[i].reshape(28,28)        
    plt.imshow(p,cmap='gray')    
    plt.axis('off')    
    plt.xticks([])    
    plt.yticks([])
    if tf.argmax(test_y[i])== tf.argmax(predict[i]):             
    	plt.title(str(tf.argmax(test_y[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='green')       else:        plt.title(str(tf.argmax(test_y[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='red')
    	plt.show()
    	
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

训练的结果:
在这里插入图片描述

2)通过加载训练好的模型进行预测

import tensorflow as tf
import matplotlib.pyplot as plt

#1.加载数据
mnist = tf.keras.datasets.mnist
(train_x,train_y),(test_x,test_y) = mnist.load_data()

#2.对测试数据进行数据处理
test_x = test_x.reshape(10000,28*28)/255.0
test_y = tf.one_hot(test_y,10)

#3.加载训练好的模型
model = tf.keras.models.load_model('model/my_model.h5')

#4.预测并可视化验证,即模型验证评估
n = 10
predict = model.predict(test_x[:10],n)
fig = plt.figure(figsize=(10,2))
for i in range(n):
    plt.subplot(1,10,i+1)        
    p = test_x[i].reshape(28,28)        
    plt.imshow(p,cmap='gray')    
    plt.axis('off')    
    plt.xticks([])    
    plt.yticks([])
    if tf.argmax(test_y[i])== tf.argmax(predict[i]):
            plt.title(str(tf.argmax(test_y[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='green')       
    else:
            plt.title(str(tf.argmax(test_y[i]).numpy()) + ',' + str(tf.argmax(predict[i]).numpy()),color='red')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

加载模型后进行可视化验证的结果(效果比较好,选了测试集前10张做了可视化,哈哈)
在这里插入图片描述

综上,算是对上次的博客的优化吧,,如有不足之处,恳请指出哈。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/350194
推荐阅读
相关标签
  

闽ICP备14008679号