当前位置:   article > 正文

(四)Tensorflow的多层感知机(DNN)模型_tensorflow dnn模型code

tensorflow dnn模型code

使用 TensorFlow 搭建含两层隐藏层的 DNN 模型进行手写数字(MNIST)识别,迭代 50 次后准确率基本达到 99%+。

代码展示:

  1. import tensorflow as tf
  2. import math
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. def multilayer_perception():
  5. #准备数据
  6. mnist = input_data.read_data_sets(r'C:\Users\Administrator\Desktop\AI_project\tensorflow\MNIST_data', one_hot=True)
  7. batch_size = 128
  8. #给数据准备好placeholder
  9. x = tf.placeholder(tf.float32,[batch_size,784],name="x_placeholder")
  10. y = tf.placeholder(tf.float32,[batch_size,10],name="y_placeholder")
  11. #初始化参数,第一个隐藏层,第二个隐藏层,输出层
  12. n_input = 784
  13. n_hidden1,n_hidden2 = 256,256
  14. n_class = 10
  15. weight = {
  16. "h1":tf.Variable(tf.random_normal([n_input,n_hidden1],name="w1")),
  17. "h2":tf.Variable(tf.random_normal([n_hidden1,n_hidden2],name="w2")),
  18. "out":tf.Variable(tf.random_normal([n_hidden2,n_class],name="w"))
  19. }
  20. biase = {
  21. "h1":tf.Variable(tf.zeros([n_hidden1],name="b1")),
  22. "h2":tf.Variable(tf.zeros([n_hidden2],name="b2")),
  23. "out":tf.Variable(tf.zeros([n_class]),name="bias")
  24. }
  25. #构建网络,得到输出层的结果
  26. #第一个隐藏层
  27. layer_1 = tf.add(biase["h1"],tf.matmul(x,weight["h1"]))
  28. layer_1 = tf.nn.relu(layer_1)
  29. #第二个隐藏层
  30. layer_2 = tf.add(biase["h2"],tf.matmul(layer_1,weight["h2"]))
  31. layer_2 = tf.nn.relu(layer_2)
  32. pred = tf.add(tf.matmul(layer_2,weight["out"]),biase["out"])
  33. #构建损失函数和优化器
  34. loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=pred))
  35. train = tf.train.AdamOptimizer().minimize(loss)
  36. #模型预测
  37. prediction = tf.equal(tf.argmax(pred,axis=1),tf.argmax(y,axis=1))
  38. accuary = tf.reduce_mean(tf.cast(prediction,tf.float32))
  39. #初始化变量
  40. init = tf.global_variables_initializer()
  41. #定义session来运行计算
  42. with tf.Session() as sess:
  43. sess.run(init)
  44. writer = tf.summary.FileWriter("./logs",sess.graph)
  45. writer.close()
  46. n_batch_train = math.ceil(mnist.train.num_examples/batch_size)
  47. n_batch_test = math.ceil(mnist.test.num_examples/batch_size)
  48. for i in range(50):
  49. # 训练
  50. loss_total = 0
  51. for _ in range(n_batch_train):
  52. x_input, y_input = mnist.train.next_batch(batch_size)
  53. _, l = sess.run([train, loss], feed_dict={x: x_input, y: y_input})
  54. loss_total += l
  55. # 测试
  56. accuary_total = 0
  57. for _ in range(n_batch_test):
  58. x_input, y_input = mnist.train.next_batch(batch_size)
  59. accuary_total += sess.run(accuary, feed_dict={x: x_input, y: y_input})
  60. print("Iteration:{},loss:{},accuary:{}".format(i,loss_total/n_batch_train,accuary_total/n_batch_test))

效果展示:

  1. Iteration:0,loss:174.60768956916277,accuary:0.8553204113924051
  2. Iteration:1,loss:43.34331250301627,accuary:0.8945806962025317
  3. Iteration:2,loss:27.401089508588925,accuary:0.9154469936708861
  4. Iteration:3,loss:19.30703142500201,accuary:0.9252373417721519
  5. Iteration:4,loss:14.60925396680832,accuary:0.939873417721519
  6. Iteration:5,loss:11.22583605861941,accuary:0.9509493670886076
  7. Iteration:6,loss:8.509260208107705,accuary:0.962618670886076
  8. Iteration:7,loss:6.391588538559135,accuary:0.9598496835443038
  9. Iteration:8,loss:4.8431818048346535,accuary:0.9668710443037974
  10. Iteration:9,loss:3.9486298046364,accuary:0.9698378164556962
  11. Iteration:10,loss:2.9593446985000207,accuary:0.977254746835443
  12. Iteration:11,loss:2.2546944803228146,accuary:0.9822982594936709
  13. Iteration:12,loss:1.6660885298288413,accuary:0.9841772151898734
  14. Iteration:13,loss:1.2757584760756577,accuary:0.985067246835443
  15. Iteration:14,loss:1.0731323728177171,accuary:0.9853639240506329
  16. Iteration:15,loss:0.9020052603093415,accuary:0.9900118670886076
  17. Iteration:16,loss:0.6498902476481249,accuary:0.9878362341772152
  18. Iteration:17,loss:0.4907471142712625,accuary:0.9933742088607594
  19. Iteration:18,loss:0.5058308260880081,accuary:0.9913963607594937
  20. Iteration:19,loss:0.43065127376118356,accuary:0.9935719936708861
  21. Iteration:20,loss:0.4260103874590171,accuary:0.9930775316455697
  22. Iteration:21,loss:0.2671571950273431,accuary:0.994560917721519
  23. Iteration:22,loss:0.2717032013880649,accuary:0.9955498417721519
  24. Iteration:23,loss:0.22175536570657886,accuary:0.9936708860759493
  25. Iteration:24,loss:0.38378527603103024,accuary:0.9921875
  26. Iteration:25,loss:0.25178228678203673,accuary:0.9948575949367089
  27. Iteration:26,loss:0.2930350419876417,accuary:0.9957476265822784
  28. Iteration:27,loss:0.23828124538560924,accuary:0.9910996835443038
  29. Iteration:28,loss:0.27899719219000174,accuary:0.9941653481012658
  30. Iteration:29,loss:0.19610462937318568,accuary:0.9971321202531646
  31. Iteration:30,loss:0.17896748130090284,accuary:0.9967365506329114
  32. Iteration:31,loss:0.20533486828761852,accuary:0.9956487341772152
  33. Iteration:32,loss:0.18936780759917962,accuary:0.9947587025316456
  34. Iteration:33,loss:0.19822018125834942,accuary:0.9971321202531646
  35. Iteration:34,loss:0.18711463891170313,accuary:0.997626582278481
  36. Iteration:35,loss:0.19687060752992383,accuary:0.9961431962025317
  37. Iteration:36,loss:0.21515094438084367,accuary:0.9978243670886076
  38. Iteration:37,loss:0.15841617382851192,accuary:0.9959454113924051
  39. Iteration:38,loss:0.1879000129068316,accuary:0.9971321202531646
  40. Iteration:39,loss:0.20046837705686385,accuary:0.9937697784810127
  41. Iteration:40,loss:0.19783038984205953,accuary:0.9980221518987342
  42. Iteration:41,loss:0.12775524038870417,accuary:0.9974287974683544
  43. Iteration:42,loss:0.1382192151944165,accuary:0.9977254746835443
  44. Iteration:43,loss:0.19454795052924856,accuary:0.9988132911392406
  45. Iteration:44,loss:0.1327811548294071,accuary:0.9966376582278481
  46. Iteration:45,loss:0.12060708776144188,accuary:0.9979232594936709
  47. Iteration:46,loss:0.1333126811136053,accuary:0.9968354430379747
  48. Iteration:47,loss:0.2438494389545342,accuary:0.9974287974683544
  49. Iteration:48,loss:0.14213542047096844,accuary:0.9972310126582279
  50. Iteration:49,loss:0.08317013298852893,accuary:0.9986155063291139

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/697776
推荐阅读
相关标签
  

闽ICP备14008679号