当前位置:   article > 正文

【Tensorflow】Fine-tuning_tensorflow start fine tuning

tensorflow start fine tuning

讲模型保存的:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

fine-tuning

使用已经预训练好的模型,自己fine-tuning。

1、首先获得pre-traing的graph结构,saver = tf.train.import_meta_graph('my_test_model-1000.meta')

2、加载参数,saver.restore(sess,tf.train.latest_checkpoint('./'))

3、准备feed_dict,新的训练数据或者测试数据。这样就可以使用同样的模型,训练或者测试不同的数据。

4、如果想在已有的网络结构上添加新的层,如前面卷积网络,获得fc2时,然后添加了一个全连接层和输出层。

  1. pred_y = graph.get_tensor_by_name("fc2/add:0")
  2. ## add the new layers
  3. weights = tf.Variable(tf.truncated_normal([4, 6], stddev=0.1), name="w")
  4. biases = tf.Variable(tf.constant(0.1, shape=[6]), name="b")
  5. conv1 = tf.matmul(pred_y, weights) + biases
  6. output1 = tf.nn.softmax(conv1)

5、只要加载模型的前一部分,然后从后面开始fine-tuning。

  1. # pre-train and fine-tuning
  2. fc2 = graph.get_tensor_by_name("fc2/add:0")
  3. fc2 = tf.stop_gradient(fc2) # stop the gradient compute
  4. fc2_shape = fc2.get_shape().as_list()
  5. # fine -tuning
  6. new_nums = 6
  7. weights = tf.Variable(tf.truncated_normal([fc2_shape[1], new_nums], stddev=0.1), name="w")
  8. biases = tf.Variable(tf.constant(0.1, shape=[new_nums]), name="b")
  9. conv2 = tf.matmul(fc2, weights) + biases
  10. output2 = tf.nn.softmax(conv2)

知识点

1、.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection。

2、checkpoint文件:一个二进制文件,包含了weights, biases, gradients和其他variables的值。但是0.11版本后的都修改了,用.data和.index保存值,用checkpoint记录最新的记录。

3、在进行保存时,因为meta中保存的模型的graph,这个是一样的,只需保存一次就可以,所以可以设置saver.save(sess, 'my-model', write_meta_graph=False)即可。

4、如果想设置每多长时间保存一次,可以设置saver = tf.train.Saver(keep_checkpoint_every_n_hours=2),这个是每2个小时保存一次。

5、如果不想保存所有变量,可以在创建saver实例时,指定保存的变量,可以以list或者dict的类型保存。如:

  1. w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
  2. w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
  3. saver = tf.train.Saver([w1,w2])

转:【tensorflow】保存模型、再次加载模型等操作 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/732201
推荐阅读
相关标签
  

闽ICP备14008679号