赞
踩
讲模型保存的:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
使用已经预训练好的模型,自己fine-tuning。
1、首先获得pre-traing的graph结构,saver = tf.train.import_meta_graph('my_test_model-1000.meta')
2、加载参数,saver.restore(sess,tf.train.latest_checkpoint('./'))
3、准备feed_dict,新的训练数据或者测试数据。这样就可以使用同样的模型,训练或者测试不同的数据。
4、如果想在已有的网络结构上添加新的层,如前面卷积网络,获得fc2时,然后添加了一个全连接层和输出层。
- pred_y = graph.get_tensor_by_name("fc2/add:0")
-
- ## add the new layers
- weights = tf.Variable(tf.truncated_normal([4, 6], stddev=0.1), name="w")
- biases = tf.Variable(tf.constant(0.1, shape=[6]), name="b")
- conv1 = tf.matmul(pred_y, weights) + biases
- output1 = tf.nn.softmax(conv1)
5、只要加载模型的前一部分,然后从后面开始fine-tuning。
- # pre-train and fine-tuning
- fc2 = graph.get_tensor_by_name("fc2/add:0")
- fc2 = tf.stop_gradient(fc2) # stop the gradient compute
- fc2_shape = fc2.get_shape().as_list()
-
- # fine -tuning
- new_nums = 6
- weights = tf.Variable(tf.truncated_normal([fc2_shape[1], new_nums], stddev=0.1), name="w")
- biases = tf.Variable(tf.constant(0.1, shape=[new_nums]), name="b")
- conv2 = tf.matmul(fc2, weights) + biases
- output2 = tf.nn.softmax(conv2)
知识点
1、.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection。
2、checkpoint文件:一个二进制文件,包含了weights, biases, gradients和其他variables的值。但是0.11版本后的都修改了,用.data和.index保存值,用checkpoint记录最新的记录。
3、在进行保存时,因为meta中保存的模型的graph,这个是一样的,只需保存一次就可以,所以可以设置saver.save(sess, 'my-model', write_meta_graph=False)
即可。
4、如果想设置每多长时间保存一次,可以设置saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
,这个是每2个小时保存一次。
5、如果不想保存所有变量,可以在创建saver实例时,指定保存的变量,可以以list或者dict的类型保存。如:
- w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
- w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
- saver = tf.train.Saver([w1,w2])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。