当前位置:   article > 正文

使用MNIST数据集训练手写数字识别模型

mnist数据集

一、MNIST数据集介绍
MNIST 数据集(手写数字数据集)是一个公开的公共数据集,任何人都可以免费获取它。目前,它已经是一个作为机器学习入门的通用性特别强的数据集之一,所以对于想要学习机器学习分类的、深度神经网络分类的、图像识别与处理的小伙伴,都可以选择MNIST数据集入门。

二、MNIST数据集结构
MNIST 数据集包含70000(60000+10000)个样本,其中有60000个训练样本和10000个测试样本,每个样本的像素大小为28*28。

1.MNIST数据集下载方式


方法一
下载地址:http://yann.lecun.com/exdb/mnist/

可以直接下载这四个文件,这四个文件分别为:
①训练样本的图像(60000个)
②对应训练样本上每一张图像上数字的标签(0~9)(60000个)
③测试样本的图像(10000个)
④对应测试样本上每一张图像上数字的标签(0~9)(10000个)

方法二
在Keras中已经内置了多种公共数据集,其中就包含MNIST数据集,如图所示。

所以可以直接调用 tf.keras.datasets.mnist,直接下载数据集。


2.开始训练

可以跟着一步一步做,不会出错

(1)导包

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import time

(2)打印开始时间

  1. print('--------------')
  2. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  3. print(nowtime)

 效果展示:

(3)预处理

  1. #初始化
  2. plt.rcParams['font.sans-serif'] = ['SimHei']
  3. #加载数据
  4. mnist = tf.keras.datasets.mnist
  5. (train_x,train_y),(test_x,test_y) = mnist.load_data()
  6. print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))
  7. #数据预处理
  8. #X_train = train_x.reshape((60000,28*28))
  9. #Y_train = train_y.reshape((60000,28*28)) #后面采用tf.keras.layers.Flatten()改变数组形状
  10. X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) #归一化
  11. y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

效果展示:

(4)建立模型查看结构

  1. # 建立模型
  2. model = tf.keras.Sequential([
  3. tf.keras.layers.Flatten(input_shape=(28, 28)),
  4. tf.keras.layers.Dense(128, activation='relu'),
  5. tf.keras.layers.Dense(10, activation='softmax')
  6. ])
  7. print('\n',model.summary()) #查看网络结构和参数信息

    效果展示:       

(5)开始训练

  1. #配置模型训练方法
  2. #adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
  3. model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
  4. #训练模型
  5. #批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
  6. print('--------------')
  7. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  8. print('训练前时刻:'+str(nowtime))
  9. history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)
  10. print('--------------')
  11. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  12. print('训练后时刻:'+str(nowtime))

效果展示:

(6)评估模型

  1. #评估模型
  2. model.evaluate(X_test,y_test,verbose=2) #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力

效果展示:

(7)结果可视化

  1. #结果可视化
  2. print(history.history)
  3. loss = history.history['loss'] #训练集损失
  4. val_loss = history.history['val_loss'] #测试集损失
  5. acc = history.history['sparse_categorical_accuracy'] #训练集准确率
  6. val_acc = history.history['val_sparse_categorical_accuracy'] #测试集准确率
  7. plt.figure(figsize=(10,3))
  8. plt.subplot(121)
  9. plt.plot(loss,color='b',label='train')
  10. plt.plot(val_loss,color='r',label='test')
  11. plt.ylabel('loss')
  12. plt.legend()
  13. plt.subplot(122)
  14. plt.plot(acc,color='b',label='train')
  15. plt.plot(val_acc,color='r',label='test')
  16. plt.ylabel('Accuracy')
  17. plt.legend()
  18. #暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
  19. #根据需要自行选择
  20. #plt.ion() #打开交互式操作模式
  21. #plt.show()
  22. #plt.pause(5)
  23. #plt.close()
  24. #使用模型
  25. plt.figure()
  26. for i in range(10):
  27. num = np.random.randint(1,10000)
  28. plt.subplot(2,5,i+1)
  29. plt.axis('off')
  30. plt.imshow(test_x[num],cmap='gray')
  31. demo = tf.reshape(X_test[num],(1,28,28))
  32. y_pred = np.argmax(model.predict(demo))
  33. plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
  34. #y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
  35. #print('X_test[0:5]: %s'%(X_test[0:5].shape))
  36. #print('y_pred: %s'%(y_pred))
  37. #plt.ion() #打开交互式操作模式
  38. plt.show()
  39. #plt.pause(5)
  40. #plt.close()

展示效果:

3.测试模型

1.修改测试图片的路径

2.修改保存模型的路径

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import cv2
  5. # 建立模型
  6. model = tf.keras.Sequential()
  7. model.add(tf.keras.layers.Flatten(input_shape=(28,28))) # 添加Flatten层说明输入数据的形状
  8. model.add(tf.keras.layers.Dense(128, activation='relu')) # 添加隐含层,为全连接层,128个节点,relu激活函数
  9. model.add(tf.keras.layers.Dense(10, activation='softmax')) # 添加输出层,为全连接层,10个节点,softmax激活函数
  10. # 加载模型参数
  11. model.load_weights('mnist_weights.h5') # 路径根据文件实际位置修改,不然会报错
  12. # 定义一个函数来预处理图片
  13. def preprocess_image(image_path):
  14. # 读取图片,转换为灰度图
  15. img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
  16. if img is None:
  17. print(f"无法加载图片:{image_path}")
  18. return None
  19. # 调整图片大小为28x28像素
  20. img = cv2.resize(img, (28, 28))
  21. # 归一化图片像素值到0-1范围
  22. img = img / 255.0
  23. # 转换图片形状以匹配模型输入
  24. img = img.reshape(1, 28, 28)
  25. return img
  26. # 使用模型进行预测
  27. plt.figure()
  28. # 这里替换为你的图片路径列表
  29. image_paths = ['shouxieti_img/test/0_10.jpg']
  30. for i, image_path in enumerate(image_paths):
  31. # 预处理图片
  32. img = preprocess_image(image_path)
  33. if img is not None:
  34. # 使用模型进行预测
  35. y_pred = np.argmax(model.predict(img))
  36. # 显示图片和预测结果
  37. plt.subplot(1, len(image_paths), i+1)
  38. plt.imshow(img[0], cmap='gray')
  39. plt.axis('off')
  40. plt.title('预测值:' + str(y_pred))
  41. plt.show()

 效果展示:

话不多说 源码奉上!

 4.全部代码

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import time
  5. print('--------------')
  6. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  7. print(nowtime)
  8. #初始化
  9. plt.rcParams['font.sans-serif'] = ['SimHei']
  10. #加载数据
  11. mnist = tf.keras.datasets.mnist
  12. (train_x,train_y),(test_x,test_y) = mnist.load_data()
  13. print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))
  14. #数据预处理
  15. #X_train = train_x.reshape((60000,28*28))
  16. #Y_train = train_y.reshape((60000,28*28)) #后面采用tf.keras.layers.Flatten()改变数组形状
  17. X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) #归一化
  18. y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)
  19. # 建立模型
  20. model = tf.keras.Sequential([
  21. tf.keras.layers.Flatten(input_shape=(28, 28)),
  22. tf.keras.layers.Dense(128, activation='relu'),
  23. tf.keras.layers.Dense(10, activation='softmax')
  24. ])
  25. print('\n',model.summary()) #查看网络结构和参数信息
  26. #配置模型训练方法
  27. #adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
  28. model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
  29. #训练模型
  30. #批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
  31. print('--------------')
  32. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  33. print('训练前时刻:'+str(nowtime))
  34. history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)
  35. print('--------------')
  36. nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
  37. print('训练后时刻:'+str(nowtime))
  38. #评估模型
  39. model.evaluate(X_test,y_test,verbose=2) #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力
  40. #结果可视化
  41. print(history.history)
  42. loss = history.history['loss'] #训练集损失
  43. val_loss = history.history['val_loss'] #测试集损失
  44. acc = history.history['sparse_categorical_accuracy'] #训练集准确率
  45. val_acc = history.history['val_sparse_categorical_accuracy'] #测试集准确率
  46. plt.figure(figsize=(10,3))
  47. plt.subplot(121)
  48. plt.plot(loss,color='b',label='train')
  49. plt.plot(val_loss,color='r',label='test')
  50. plt.ylabel('loss')
  51. plt.legend()
  52. plt.subplot(122)
  53. plt.plot(acc,color='b',label='train')
  54. plt.plot(val_acc,color='r',label='test')
  55. plt.ylabel('Accuracy')
  56. plt.legend()
  57. #暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
  58. #根据需要自行选择
  59. #plt.ion() #打开交互式操作模式
  60. #plt.show()
  61. #plt.pause(5)
  62. #plt.close()
  63. #使用模型
  64. plt.figure()
  65. for i in range(10):
  66. num = np.random.randint(1,10000)
  67. plt.subplot(2,5,i+1)
  68. plt.axis('off')
  69. plt.imshow(test_x[num],cmap='gray')
  70. demo = tf.reshape(X_test[num],(1,28,28))
  71. y_pred = np.argmax(model.predict(demo))
  72. plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
  73. #y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
  74. #print('X_test[0:5]: %s'%(X_test[0:5].shape))
  75. #print('y_pred: %s'%(y_pred))
  76. #plt.ion() #打开交互式操作模式
  77. plt.show()
  78. #plt.pause(5)
  79. #plt.close()

 转载于:【神经网络与深度学习】使用MNIST数据集训练手写数字识别模型——[附完整训练代码]_使用mnist数据集进行模型训练时-CSDN博客

 谢谢支持!

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号