赞
踩
首先读取数据,数据源是mnist库,可以通过input_data中read_data函数直接读取数据,数据图像为28*28。
- #导入库
- import tensorflow as tf
- #下载数据对应的库
- import input_data
- import numpy as np
- import matplotlib.pyplot as plt
- print ("Packages imported")
-
- #导入mnist数据
- mnist = input_data.read_data_sets("data/", one_hot=True)#one_hot=True 表示 数据的标签是one_hot编码的,即数据标签为1*10的数组
- #读取训练数据,训练标签,测试数据,测试标签
- trainimgs, trainlabels, testimgs, testlabels \
- = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
- #获取训练数据个数,测试集数据个数,图像维度和类别数
- ntrain, ntest, dim, nclasses \
- = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
- print ("MNIST loaded")
读取数据后设置参数。本次使用LSTM作为训练模型,因此需要搭建LSTM,因图像为28*28,所以将每一行图像作为一次输入,这样每一次训练,LSTM需要运算28次,设置隐层为128,所以从输入到隐层的全连接参数为28*128个,经过运算后输出与隐层全连接参数为128*10
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。