赞
踩
代码展示:
import math

import tensorflow as tf
# NOTE(review): `tensorflow.examples.tutorials` only ships with older
# TensorFlow 1.x releases -- confirm it is available in your install.
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape, name):
    """Create a ``[fan_in, fan_out]`` weight matrix with He-normal init.

    ``tf.random_normal`` without ``stddev`` draws from N(0, 1), which gives
    very large initial logits (and loss) for a 784-input network; scaling by
    sqrt(2 / fan_in) is the usual choice for layers followed by ReLU.

    Args:
        shape: ``[fan_in, fan_out]``.
        name: Name given to the ``tf.Variable`` (shown in TensorBoard).
    """
    stddev = math.sqrt(2.0 / shape[0])
    return tf.Variable(tf.random_normal(shape, stddev=stddev), name=name)


def _bias_variable(size, name):
    """Create a zero-initialised bias vector of length *size*."""
    return tf.Variable(tf.zeros([size]), name=name)


def multilayer_perception(
        data_dir=r'C:\Users\Administrator\Desktop\AI_project\tensorflow\MNIST_data',
        batch_size=128,
        epochs=50,
        log_dir="./logs"):
    """Train a 784-256-256-10 ReLU MLP on MNIST and report test accuracy.

    Built on the TensorFlow 1.x graph API (``tf.placeholder`` / ``tf.Session``).
    After every epoch, prints the mean training loss over that epoch's batches
    and the accuracy on the full MNIST test split.

    Args:
        data_dir: Directory handed to ``input_data.read_data_sets``; labels
            are loaded one-hot.
        batch_size: Examples per optimisation step.
        epochs: Number of training epochs (outer iterations).
        log_dir: Directory the graph is written to for TensorBoard.
    """
    # Prepare data
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # Layer sizes: input, first hidden, second hidden, output classes
    n_input = 784
    n_hidden1, n_hidden2 = 256, 256
    n_class = 10

    # The batch dimension is None so the same placeholders accept a training
    # batch and the whole test split in a single feed.
    x = tf.placeholder(tf.float32, [None, n_input], name="x_placeholder")
    y = tf.placeholder(tf.float32, [None, n_class], name="y_placeholder")

    # Names go on the tf.Variable itself (not on the initialiser op) so the
    # variables are labelled in the TensorBoard graph.
    weight = {
        "h1": _weight_variable([n_input, n_hidden1], "w1"),
        "h2": _weight_variable([n_hidden1, n_hidden2], "w2"),
        "out": _weight_variable([n_hidden2, n_class], "w_out"),
    }
    bias = {
        "h1": _bias_variable(n_hidden1, "b1"),
        "h2": _bias_variable(n_hidden2, "b2"),
        "out": _bias_variable(n_class, "b_out"),
    }

    # Forward pass: two ReLU hidden layers, then linear logits
    layer_1 = tf.nn.relu(tf.add(tf.matmul(x, weight["h1"]), bias["h1"]))
    layer_2 = tf.nn.relu(tf.add(tf.matmul(layer_1, weight["h2"]), bias["h2"]))
    logits = tf.add(tf.matmul(layer_2, weight["out"]), bias["out"])

    # Loss and optimiser
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
    train = tf.train.AdamOptimizer().minimize(loss)

    # Fraction of examples whose arg-max prediction matches the one-hot label
    correct = tf.equal(tf.argmax(logits, axis=1), tf.argmax(y, axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Write the graph once for TensorBoard; no summaries are recorded.
        writer = tf.summary.FileWriter(log_dir, sess.graph)
        writer.close()

        n_batch_train = math.ceil(mnist.train.num_examples / batch_size)
        test_feed = {x: mnist.test.images, y: mnist.test.labels}
        for i in range(epochs):
            # Training
            loss_total = 0
            for _ in range(n_batch_train):
                x_input, y_input = mnist.train.next_batch(batch_size)
                _, batch_loss = sess.run(
                    [train, loss], feed_dict={x: x_input, y: y_input})
                loss_total += batch_loss
            # Evaluate on the held-out test split in one exact pass; drawing
            # batches from mnist.train here would report training accuracy.
            test_accuracy = sess.run(accuracy, feed_dict=test_feed)
            print("Iteration:{},loss:{},accuary:{}".format(
                i, loss_total / n_batch_train, test_accuracy))


if __name__ == "__main__":
    multilayer_perception()
效果展示(注意:以下结果是在“测试”循环调用 mnist.train.next_batch 的代码版本下得到的,因此 accuary 实际上是训练集批次上的准确率,而不是测试集准确率):
- Iteration:0,loss:174.60768956916277,accuary:0.8553204113924051
- Iteration:1,loss:43.34331250301627,accuary:0.8945806962025317
- Iteration:2,loss:27.401089508588925,accuary:0.9154469936708861
- Iteration:3,loss:19.30703142500201,accuary:0.9252373417721519
- Iteration:4,loss:14.60925396680832,accuary:0.939873417721519
- Iteration:5,loss:11.22583605861941,accuary:0.9509493670886076
- Iteration:6,loss:8.509260208107705,accuary:0.962618670886076
- Iteration:7,loss:6.391588538559135,accuary:0.9598496835443038
- Iteration:8,loss:4.8431818048346535,accuary:0.9668710443037974
- Iteration:9,loss:3.9486298046364,accuary:0.9698378164556962
- Iteration:10,loss:2.9593446985000207,accuary:0.977254746835443
- Iteration:11,loss:2.2546944803228146,accuary:0.9822982594936709
- Iteration:12,loss:1.6660885298288413,accuary:0.9841772151898734
- Iteration:13,loss:1.2757584760756577,accuary:0.985067246835443
- Iteration:14,loss:1.0731323728177171,accuary:0.9853639240506329
- Iteration:15,loss:0.9020052603093415,accuary:0.9900118670886076
- Iteration:16,loss:0.6498902476481249,accuary:0.9878362341772152
- Iteration:17,loss:0.4907471142712625,accuary:0.9933742088607594
- Iteration:18,loss:0.5058308260880081,accuary:0.9913963607594937
- Iteration:19,loss:0.43065127376118356,accuary:0.9935719936708861
- Iteration:20,loss:0.4260103874590171,accuary:0.9930775316455697
- Iteration:21,loss:0.2671571950273431,accuary:0.994560917721519
- Iteration:22,loss:0.2717032013880649,accuary:0.9955498417721519
- Iteration:23,loss:0.22175536570657886,accuary:0.9936708860759493
- Iteration:24,loss:0.38378527603103024,accuary:0.9921875
- Iteration:25,loss:0.25178228678203673,accuary:0.9948575949367089
- Iteration:26,loss:0.2930350419876417,accuary:0.9957476265822784
- Iteration:27,loss:0.23828124538560924,accuary:0.9910996835443038
- Iteration:28,loss:0.27899719219000174,accuary:0.9941653481012658
- Iteration:29,loss:0.19610462937318568,accuary:0.9971321202531646
- Iteration:30,loss:0.17896748130090284,accuary:0.9967365506329114
- Iteration:31,loss:0.20533486828761852,accuary:0.9956487341772152
- Iteration:32,loss:0.18936780759917962,accuary:0.9947587025316456
- Iteration:33,loss:0.19822018125834942,accuary:0.9971321202531646
- Iteration:34,loss:0.18711463891170313,accuary:0.997626582278481
- Iteration:35,loss:0.19687060752992383,accuary:0.9961431962025317
- Iteration:36,loss:0.21515094438084367,accuary:0.9978243670886076
- Iteration:37,loss:0.15841617382851192,accuary:0.9959454113924051
- Iteration:38,loss:0.1879000129068316,accuary:0.9971321202531646
- Iteration:39,loss:0.20046837705686385,accuary:0.9937697784810127
- Iteration:40,loss:0.19783038984205953,accuary:0.9980221518987342
- Iteration:41,loss:0.12775524038870417,accuary:0.9974287974683544
- Iteration:42,loss:0.1382192151944165,accuary:0.9977254746835443
- Iteration:43,loss:0.19454795052924856,accuary:0.9988132911392406
- Iteration:44,loss:0.1327811548294071,accuary:0.9966376582278481
- Iteration:45,loss:0.12060708776144188,accuary:0.9979232594936709
- Iteration:46,loss:0.1333126811136053,accuary:0.9968354430379747
- Iteration:47,loss:0.2438494389545342,accuary:0.9974287974683544
- Iteration:48,loss:0.14213542047096844,accuary:0.9972310126582279
- Iteration:49,loss:0.08317013298852893,accuary:0.9986155063291139
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。