当前位置:   article > 正文

深度学习快速入门项目教程-基于keras的手写数字识别_tesorflow keras 数字识别入门教程

tesorflow keras 数字识别入门教程

前言

        本项目是基于python3.11,keras2.15版本在pycharm中进行的,属于新手必做项目,本文用目前较新的python及其库的版本,演示了具体流程。

导入所需库及工具包

  1. # 导入相关工具包
  2. import numpy as np
  3. # import tensorflow as tf
  4. import keras
  5. import matplotlib.pyplot as plt
  6. # 导入数据集
  7. from keras.datasets import mnist
  8. # 构建序列模型
  9. from keras.models import Sequential
  10. # 导入需要的层
  11. from keras.layers import Dense, Dropout, Activation, BatchNormalization
  12. # 导入辅助工具包
  13. from keras import utils
  14. from keras import regularizers

        numpy库以及matplotlib库在数据分析及可视化中必不可少的,用于后续的数据处理部分,注意在当前版本下,是不需要从tensorflow中导入keras的,而是直接导入keras库。

数据加载

        mnist数据集共有从0-9共10个类别标签,先定义出来,方便使用,数据集的load_data函数自动将数据集分为训练集与测试集,为了对数据集有一个直观印象,可以将数据集中x_train的shape属性打印出来是(60000,28,28),60000表示一共有60000张图片,28表示每张图片的长和宽各有28个像素,并且可以用subplot()函数绘制数据集的前九张图,结果如图1所示,可以在pycharm中的sciview中查看。

  1. # 数据加载
  2. # 定义类别数
  3. num_classes = 10
  4. (x_train , y_train) , (x_test , y_test) = mnist.load_data()
  5. # 打印输出数据集的维度
  6. # print("训练样本初始维度:",x_train.shape)
  7. # print(x_train)
  8. # 数据展示
  9. for i in range(9):
  10. plt.subplot(3,3,i+1)
  11. plt.subplots_adjust(hspace=0.5)
  12. plt.imshow(x_train[i],cmap = 'gray',interpolation = 'none')
  13. plt.title("number{}".format(y_train[i]))
  14. plt.show()

图1 mnist前9张图像

数据预处理

因为模型仅支持向量形式,所以要对数据原来的shape属性进行更改,采用reshape函数,将其变为(60000,784)的向量,并且将向量中的每个元素改为float型,因为图片是灰度图的形式,每个元素的范围是0~255,对其进行规范化处理,每个元素除以255,再将目标值转换为热编码的形式。

  1. # 数据处理
  2. # 格式转换
  3. x_train = x_train.reshape(60000 , 784)
  4. x_test = x_test.reshape(10000 , 784)
  5. x_train = x_train.astype('float32')
  6. x_test = x_test.astype('float32')
  7. # 标准化
  8. x_train /= 255
  9. x_test /= 255
  10. # print("训练样本维度:",x_train.shape)
  11. # 将目标值转换为热编码的形式
  12. y_train = utils.to_categorical(y_train,num_classes)
  13. y_test = utils.to_categorical(y_test,num_classes)
  14. # print("目标值维度:",y_test.shape)
  15. # print(y_test )

模型搭建

模型搭建采用的是Sequential模型,主要包括两个隐层,一个输出层,两个隐层中又分别添加了BN层,激活函数都为ReLu。BN与激活函数的顺序可以调换,再以0.2的概率随机失活一部分神经元,最后输出层选择激活函数为softmax,model.summary()函数可以展示模型结构,如图2所示。

  1. # 模型搭建
  2. model = Sequential()
  3. # 全连接层,两个隐层,一个输出层
  4. # 第一个隐层,512个神经元,先激活后BN,激活函数为RELU,以0.2的概率随机失活
  5. model.add(Dense(512,activation = "relu",input_shape=(784,)))
  6. model.add(BatchNormalization())
  7. model.add(Dropout(0.2))
  8. # 第二个隐层,512个神经元,先BN后激活,激活函数为RELU,以0.2的概率随机失活
  9. model.add(Dense(512,kernel_regularizer=regularizers.l2(0.01)))
  10. model.add(BatchNormalization())
  11. model.add(Activation("relu"))
  12. model.add(Dropout(0.2))
  13. # 输出层,10个神经元,激活函数为softmax
  14. model.add(Dense(10,activation = "softmax"))
  15. #模型展示
  16. print(model.summary())
  17. # utils.plot_model(model,show_shapes=True,to_file="model.svg",dpi=None)

图2 模型结构

模型结构也可以用utils.plot_model()绘制 ,前提是要安装pydot库以及Graphviz,并配置好环境变量,才可以使用。

模型编译与训练

模型编译中需要的设置的参数包括损失函数(交叉熵损失函数),优化器(Adam),以及评价指标(accuracy),模型训练使用model.fit()函数,先输入训练的图像(x_train)及其标签值(y_train),一次性输入网络的数量(batch_size),训练的轮数(epochs),以及分割的测试集,最后用model.save()来保存训练好的模型,与以往不同的一点是保存文件的后缀由.h5改为了.keras

  1. # 模型编译
  2. model.compile(loss='categorical_crossentropy',optimizer='adam',metrics='accuracy')
  3. # 模型训练
  4. history = model.fit(x_train,y_train,batch_size=200,epochs=50,verbose=1,validation_data=(x_test ,y_test))
  5. model.save("my_model_100.keras")

后处理以及模型评估

后处理以图像的形式展示模型训练的结果,以下是绘制损失曲线与准确率曲线的代码。

  1. # 绘制损失函数曲线
  2. plt.figure()
  3. plt.plot(history.history["loss"],label="train_loss" )
  4. plt.plot(history.history["val_loss"],label="val_loss")
  5. plt.legend()
  6. plt.grid()
  7. plt.show()
  1. #绘制准确率图像
  2. plt.figure()
  3. plt.plot(history.history["accuracy"],label="accuracy")
  4. plt.plot(history.history["val_accuracy"],label="val_accuracy")
  5. plt.legend()
  6. plt.grid()
  7. plt.show()

 

图3 损失函数曲线

 

图4 准确率曲线

 模型评估使用model.evaluate()函数,图5为评估结果输出。

score=model.evaluate(x_test,y_test,verbose=1)

 

 图5 模型评估结果

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/751819
推荐阅读
相关标签
  

闽ICP备14008679号