赞
踩
x 是 28×28 的手写数字图片(展平成 784 维向量输入),y 是该图片对应数字的 one-hot 标签(10 维)。
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
-
# Hyper-parameters. Each 28x28 image is fed to the LSTM as a sequence of
# max_time steps, each step being one row of n_inputs pixels.
batch_size = 100
max_time = 28
n_inputs = 28
lstm_size = 100   # number of units in the LSTM cell
n_class = 10      # digits 0-9

# Labels are loaded one-hot encoded (one_hot=True).
mnist = input_data.read_data_sets('D:/tensorflow/MNIST_data/mnist', one_hot=True)

# Number of full mini-batches in one pass over the training split.
n_batch = mnist.train.num_examples // batch_size
-
def Weights_variables():
    """Create the trainable output-layer weight matrix.

    The matrix has shape [lstm_size, n_class] and is initialised from a
    truncated normal distribution with standard deviation 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape=[lstm_size, n_class], stddev=0.1))
-
def biases_variables():
    """Create the trainable output-layer bias vector (length n_class, all 0.1)."""
    return tf.Variable(tf.constant(0.1, shape=[n_class]))
-
def RNN(X):
    """Run a single-layer LSTM over each image and return its last hidden state.

    Args:
        X: float32 tensor of flattened images (e.g. the [None, 784] input
            placeholder). It is reshaped to [batch, max_time, n_inputs], so
            each image row becomes one time step.

    Returns:
        The LSTM's final hidden state ``h``, shape [batch, lstm_size].
    """
    # Enter the name scope inside the function so the ops are actually
    # created under 'RNN'. A scope wrapped around the ``def`` statement is
    # exited before any op exists and therefore groups nothing.
    with tf.name_scope('RNN'):
        inputs = tf.reshape(X, [-1, max_time, n_inputs])
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
        # The per-step outputs are not needed; only the final state is used.
        _, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
        # BasicLSTMCell's state is an LSTMStateTuple(c, h); h is the hidden
        # state (same element as index 1).
        return final_state.h
-
with tf.name_scope('input'):
    # Flattened 28x28 images: one 784-long row per example.
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    # One-hot labels over the 10 digit classes.
    y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
-
-
# Final LSTM hidden state for each image in the batch.
hiddlen = RNN(x)

# Fully connected output layer mapping the hidden state to class scores.
with tf.name_scope('W'):
    w = Weights_variables()
with tf.name_scope('b'):
    b = biases_variables()
with tf.name_scope('prediction'):
    # Keep the unnormalised scores separate: the loss below needs raw logits.
    logits = tf.matmul(hiddlen, w) + b
    # Class probabilities, used for accuracy/inference.
    prediction = tf.nn.softmax(logits)

with tf.name_scope('loss'):
    # softmax_cross_entropy_with_logits applies softmax internally. Feeding it
    # the already-softmaxed `prediction` would apply softmax twice and
    # flatten the gradients, so pass the raw logits instead.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    tf.summary.scalar('loss', loss)
with tf.name_scope('train_step'):
    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
-
with tf.name_scope('Accuracy'):
    # Fraction of examples whose most probable predicted class matches
    # the label's class.
    true_digit = tf.argmax(y, axis=1)
    predicted_digit = tf.argmax(prediction, axis=1)
    prediction_value = tf.equal(true_digit, predicted_digit)
    Accuracy = tf.reduce_mean(tf.cast(prediction_value, dtype=tf.float32))
    tf.summary.scalar('Accuracy', Accuracy)
-
-
init = tf.global_variables_initializer()
# Collect every tf.summary.scalar defined above into one op.
merged = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(init)
    writer_train = tf.summary.FileWriter('logs/train', sess.graph)
    writer_test = tf.summary.FileWriter('logs/test', sess.graph)
    try:
        # Deliberately runs 250 fewer batches than one full epoch so the demo
        # finishes quickly.
        for batch in range(n_batch - 250):
            # One optimisation step, then log training loss/accuracy.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
            summary = sess.run(merged, feed_dict={x: batch_xs, y: batch_ys})
            writer_train.add_summary(summary, batch)

            # Log the same metrics on one test mini-batch at this step.
            batch_xs, batch_ys = mnist.test.next_batch(batch_size)
            summary = sess.run(merged, feed_dict={x: batch_xs, y: batch_ys})
            writer_test.add_summary(summary, batch)
    finally:
        # FileWriter buffers events; close() flushes them so the last points
        # are not lost when the script exits.
        writer_train.close()
        writer_test.close()
其中在 RNN 的输出之后又接了一层 softmax,用来预测 0~9 这 10 个数字各自的概率。训练结果如下:
可以看到准确率呈上升趋势。不过 demo 中训练次数较少,而且为了尽快出结果,这里还特意少训练了 250 个 batch(即 `range(n_batch-250)`),所以最终结果不是很好。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。