赞
踩
- #LSTM主要用于语音、文本序列化的问题,但同样可以用于图片分类
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
-
- #载入数据集
- mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)
-
- #输入图片是28*28
- n_inputs=28#输入一行,一行有28个数据
- max_time=28#一共28行
- lstm_size=100#隐藏单元
- n_classes=10#10个分类
- batch_size=50#每批次50个样本
- n_batch=mnist.train.num_examples//batch_size#计算一共多少个批次
-
- #这里的none表示一个维度可以是任一长度
- x=tf.placeholder(tf.float32,[None,784])
- #标签
- y=tf.placeholder(tf.float32,[None,10])
-
- #初始化权值
- weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
- #初始化偏置值
- biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))
-
- #定义RNN网络
- def RNN(X,weights,biases):
- #input=[batch_size,max_time,n_inputs]
- inputs=tf.reshape(X,[-1,max_time,n_inputs])
- #定义LSTM基本CELL
- #lstm_cell=tf.contrib.
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。