当前位置:   article > 正文

基于LSTM实现mnist手写数字识别_mnist lstm

mnist lstm

首先读取数据,数据源是mnist库,可以通过input_data中read_data函数直接读取数据,数据图像为28*28。

  1. #导入库
  2. import tensorflow as tf
  3. #下载数据对应的库
  4. import input_data
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. print ("Packages imported")
  8. #导入mnist数据
  9. mnist = input_data.read_data_sets("data/", one_hot=True)#one_hot=True 表示 数据的标签是one_hot编码的,即数据标签为1*10的数组
  10. #读取训练数据,训练标签,测试数据,测试标签
  11. trainimgs, trainlabels, testimgs, testlabels \
  12. = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
  13. #获取训练数据个数,测试集数据个数,图像维度和类别数
  14. ntrain, ntest, dim, nclasses \
  15. = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
  16. print ("MNIST loaded")

读取数据后设置参数。本次使用LSTM作为训练模型,因此需要搭建LSTM,因图像为28*28,所以将每一行图像作为一次输入,这样每一次训练,LSTM需要运算28次,设置隐层为128,所以从输入到隐层的全连接参数为28*128个,经过运算后输出与隐层全连接参数为128*10

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/785898
推荐阅读
相关标签
  

闽ICP备14008679号