赞
踩
这篇文章介绍一下如何开始第一个tensorflow模型。
对于大部分机器学习模型来说,离线数据流主要可以分为特征抽取,特征预处理,模型训练,模型评估等几个部分。这里将会用简短的代码基于tensorflow实现一个简单的CNN模型。tensorflow是目前业界使用的十分广泛的深度学习框架,尽管pytorch大有后来者居上的趋势,但是tensorflow在业界应用上的王者地位依然难以动摇,特别是在生态构建和硬件支持上,pytorch暂时还没有展现出tensorflow一样的能力。tensorflow的学习和调试难度一直受人诟病,但是随着2.0版本keras的引入,框架的易用性大大提升,相信后续较长时间tensorflow依然是业界主流的深度学习框架。
1.训练数据
这里例程的选用比较简单的数据集MNIST,训练集共由四部分组成。
训练集共包含5500条样本,每条样本的图片大小为28*28。label我们这里选择one_hot编码形式,即用10个桶表示10个数字(例如数字5用one_hot表示为[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]),这种形式方便我们后续的loss计算。测试集包含1000条样本,每条样本的形式和训练集相同。
2.建立模型
tensorflow采用固化图的方式运行,主要由计算图(tf.Graph)负责,在运行前首先要构建计算图,计算图主要包含了张量(tf.Tensor)和节点(tf.Operation)。张量可以理解为数据管道,每个节点是对张量进行的变化处理。需要注意的是张量和数据并不是同一个概念,这里的张量可以理解为水管,是数据的通道,所以在没有灌入数据之前,计算图是不会直接得出准确结果的。
这里我们首先建立模型的主体,这个主体就是由计算图的形式表达的。
- def setup_model(input_x, input_y):
- """
- setup cnn model
- """
- # conv1 input_x shape: (batch_size, 28, 28, 1), filter number: 8 filter size: 3*3
- conv1 = tf.layers.conv2d(input_x, 8, 3, activation=tf.nn.relu)
- # 2*2 max pooling
- conv1 = tf.layers.max_pooling2d(conv1, 2, 2)
- conv2 = tf.layers.conv2d(conv1, 16, 3, activation=tf.nn.relu)
- conv2 = tf.layers.max_pooling2d(conv2, 2, 2)
- # flatten conv2 result
- flatten = tf.layers.flatten(conv2)
- # full connect
- dense1 = tf.layers.dense(flatten, 20, activation=tf.nn.relu)
- output = tf.layers.dense(dense1, 10, activation=tf.sigmoid)
-
- loss = loss_fuc(output, input_y)
- optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
- min_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
- accuracy = eval(output, input_y)
- return [loss, accuracy, min_op]
-
-
- def loss_fuc(y_predict, y_true):
- """
- loss function
- """
- return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=y_true))
-
-
- def eval(y_predict, y_true):
- """
- get accuracy
- """
- correct_prediction = tf.equal(tf.argmax(y_predict, 1), tf.argmax(y_true, 1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- return accuracy
模型输入为28*28*1的图片,主要包含了四层:第一个卷积层conv1,采用了最大池化。第二个卷积层和第一个卷积层结构类似,只是filter数量略有不同。第二层卷积之后,我们用flatten把所有卷积结果拉平为一维向量,然后经过了一层全连接,最后采用一层sigmoid作为数据结果。
loss_fuc定义了loss的计算方式。tf.nn.softmax_cross_entropy_with_logits用于计算分类结果的交叉熵loss,这里需要注意的是函数内部进行了suftmax处理,所以直接输入模型预测结果和标签就行,不需要再对模型预测结果进行softmax处理。
eval函数用于评估模型预测结果。
3.训练和测试
- def train_model(mnist):
- """
- train and test model
- """
- epoch = 100
- batch_size = 128
-
- input_x = tf.placeholder(tf.float32, [None, 28, 28, 1])
- input_y = tf.placeholder(tf.float32, [None, 10])
-
- [loss, accuracy, train_step] = setup_model(input_x, input_y)
- sess = tf.Session()
- sess.run(tf.global_variables_initializer())
-
- feature = mnist.train.images
- label = mnist.train.labels
- feature = feature.reshape((feature.shape[0], 28, 28, 1))
- steps = int(feature.shape[0] / batch_size)
-
- for i in range(epoch):
- loss_list = []
- accuracy_list = []
- for j in range(steps):
- batch = [feature[j: j + batch_size], label[j: j + batch_size]]
- loss_out, accuracy_out, _ = sess.run([loss, accuracy, train_step], feed_dict={input_x: batch[0], input_y: batch[1]})
- loss_list.append(loss_out)
- accuracy_list.append(accuracy_out)
- print ("epoch: {}, loss mean: {}, accuracy: {}".format(i, np.mean(np.array(loss_list)), np.mean(np.array(accuracy_list))))
-
- print ("test model")
- test_feature = mnist.test.images
- test_label = mnist.test.labels
- test_feature = test_feature.reshape((test_feature.shape[0], 28, 28, 1))
- test_accuracy = sess.run(accuracy, feed_dict={input_x: test_feature, input_y: test_label})
- print ("test accuracy: {}".format(test_accuracy))
训练和测试函数中我们定义了input_x和input_y两个tf.placeholder,这两个tensor可以看作是计算图接受数据流输入的接口。在计算图构建完成以后,我们定义了一个会话(tf.Session)。会话可以看作是执行计算图的指令入口,我们之前定义的计算图都是静态的,要让这个计算图run起来,就需要通过session传入相关的数据和指令。注意第一次训练前需要对计算图进行初始化操作。
setup_model函数返回了三个operation,分别是loss,accuracy,train_step。在运行计算图进行训练时,我们通过feed_dict传入训练或者测试用的数据,然后通过session传入我们想看到的结果(这里传入的是包含loss,accuracy,train_step的list)。这样session.run()函数就会运行计算图,并返回给我们想要的结果。
关于Session的几个注意事项:
1.在进行第一次训练之前记得使用tf.global_variables_initializer()初始化计算图内的所有参数。
2.在训练的for循环内切记不要修改计算图,任何增加计算图tensor或者operation的操作都可能导致训练过程中内存暴增。
3.一个会话被定义之后如果没有指定计算图就会使用默认计算图。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。