当前位置:   article > 正文

Rnn-demo_rnn demo

rnn demo
 x 是 28×28 的手写数字图片(展平成 784 维向量输入),y 是该图片对应数字的 one-hot 标签(10 维)。

 

  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. mnist=input_data.read_data_sets('D:/tensorflow/MNIST_data/mnist',one_hot=True)
  4. batch_size=100
  5. n_batch=mnist.train.num_examples//batch_size
  6. max_time=28
  7. n_inputs=28
  8. lstm_size=100
  9. n_class=10
  10. def Weights_variables():
  11. initial=tf.truncated_normal([lstm_size,n_class],stddev=0.1)
  12. return tf.Variable(initial)
  13. def biases_variables():
  14. initial=tf.constant(0.1,shape=[n_class])
  15. return tf.Variable(initial)
  16. with tf.name_scope('RNN'):
  17. def RNN(X):
  18. inputs=tf.reshape(X,[-1,max_time,n_inputs])
  19. lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
  20. output,final_output=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
  21. return final_output[1]
  22. with tf.name_scope('input'):
  23. x=tf.placeholder(tf.float32,[None,784])
  24. y=tf.placeholder(tf.float32,[None,10])
  25. hiddlen=RNN(x)
  26. with tf.name_scope('W'):
  27. w=Weights_variables()
  28. with tf.name_scope('b'):
  29. b=biases_variables()
  30. with tf.name_scope('prediction'):
  31. prediction=tf.nn.softmax(tf.matmul(hiddlen,w)+b)
  32. with tf.name_scope('loss'):
  33. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
  34. tf.summary.scalar('loss',loss)
  35. with tf.name_scope('train_step'):
  36. train_step=tf.train.AdamOptimizer(0.0001).minimize(loss)
  37. with tf.name_scope('Accuracy'):
  38. prediction_value=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
  39. Accuracy=tf.reduce_mean(tf.cast(prediction_value,dtype=tf.float32))
  40. tf.summary.scalar('Accuracy',Accuracy)
  41. init=tf.global_variables_initializer()
  42. merged=tf.summary.merge_all()
  43. with tf.Session() as sess:
  44. sess.run(init)
  45. writer_train=tf.summary.FileWriter('logs/train',sess.graph)
  46. writer_test=tf.summary.FileWriter('logs/test',sess.graph)
  47. for batch in range(n_batch-250):
  48. batch_xs,batch_ys=mnist.train.next_batch(batch_size)
  49. sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
  50. summary=sess.run(merged,feed_dict={x:batch_xs,y:batch_ys})
  51. writer_train.add_summary(summary,batch)
  52. batch_xs,batch_ys=mnist.test.next_batch(batch_size)
  53. summary=sess.run(merged,feed_dict={x:batch_xs,y:batch_ys})
  54. writer_test.add_summary(summary,batch)

 

 

其中,在 RNN 的输出后面又接了一个全连接层和 softmax,用来预测 0~9 这 10 个数字各自的概率,如下:

 

可以看到准确率呈上升趋势。但由于 demo 中训练次数较少(不到一个 epoch),而且为了尽快出结果,这里还特意少跑了 250 个 batch,即

range(n_batch-250):

所以最终结果不是很好。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/978332
推荐阅读
相关标签
  

闽ICP备14008679号