当前位置:   article > 正文

Tensorflow在不同训练场景下读取和使用不同格式pretrained model的方法_tensorflow的pretrained怎么用

tensorflow的pretrained怎么用

不同应用场景分析与示例

Tensorflow读取预训练模型是模型训练中常见的操作,通常的应用的场景包括:

1)训练中断后需要重新开始,将保存之前的checkpoint(包括.data .meta .index checkpoint这四个文件),然后重新加载模型,从上次断点处继续训练或预测。实现方法如下:

  • 如果代码中已经构建好了网络结构图
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints_path'))
    # 如果checkpoint存在则加载断点之前的训练模型
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 如果用别人完整的checkpoint文件,但自己没有搭建网络结构的代码,可以通过.meta文件来重构网络,并检查权重数据:
# ref: http://blog.csdn.net/spylyt/article/details/71601174
import tensorflow as tf 

def restore_from_meta(sess, graph, ckpt_path):
    with graph.as_default():
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            print('Found checkpoint, try to restore...')
            saver = tf.train.import_meta_graph(''.join([ckpt.model_checkpoint_path, '.meta']))
            saver.restore(sess, ckpt.model_checkpoint_path)

        #### Check
        # 打印出网络中可训练的权重参数名
        for var in tf.trainable_variables():
            print(var)
        # 根据变量名称读取所需变量权重,打印其形状
        conv1_2_weights = graph.get_tensor_by_name('conv1_2/weights:0')
        print(conv1_2_weights.shape)

if __name__ == '__main__':
    ckpt_path = 'data\\VOC2007_images\\GAP_backup\\'
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    restore_from_meta(sess, graph, ckpt_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

2)用上述f.train.Saver().restore(sess, ckpt_path)方法来加载模型会将所有权重全部加载,如果你希望加载指定几层权重(比如在做transfer learning的时候),可以通过下面方法实现:

# ref: http://blog.csdn.net/qq_25737169/article/details/78125061
sess = tf.Session()
var = tf.global_variables()
var_to_restore = [val  for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val  for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

这样,就值加载了conv1conv2的权重,而其他层中权重则通过网络初始化的方式获得。

如果用tensorflowslim来构建网络的话,操作会更加简单:

exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))
  • 1
  • 2
  • 3
  • 4

该操作是restore除了layer1layer2之外的其他权重。


上述两种情形是自己搭建了网络结构图或者虽然自己没有网络结构图,但手头有对应的.meta文件,可以重构网络。如果只是有该网络的预训练模型权重,希望借用该模型中的某些层的权重来初始化自己的网络。这时候面对的场景就不一样了。

一般的模型权重会保存成:

1)checkpoint权重(model_ckpt.data)或者frozen graph(.pb)格式的文件

  • model_ckpt.data权重可通过tensorflow.python.pywrap_tensorflow.NewCheckpointReadertf.train.NewCheckpointReader 获取,二者的功能一样,下面以前者举例:
import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join(model_dir, "model_ckpt.ckpt")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出权重tensor名字和值
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

如果你想用某些层的pretrained权重来初始化你自己的网络,可以在sess中通过下面的操作完成:

with tf.variable_scope('', reuse = True):
        sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
  • 1
  • 2
  • .pb权重图可通过下面的方法获取:
import os
import tensorflow as tf

var1_name = 'pool_3/_reshape:0'
var2_name = 'data/content:0'

graph = tf.Graph()
with graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('model_graph.pb', 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        var1_tensor, var2_tensor = tf.import_graph_def(od_graph_def, return_elements = [var1_name, var2_name])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

得到变量1 pool_3/_reshape:0和变量2 data/content:0的值。

2)如果预训练模型是由caffe model转过来的,这时候可能的格式为numpy文件(.npy),它按照字典的方式来存储各层的权重数据。下面举一个例子说明:

import numpy as np
import cPickle

layer_name = "conv1"
with open('model.npy') as f:
    pretrained_weight = cPickle.load(f)
    layer = pretrained_weight[layer_name]
    weights = layer[0]
    biases = layer[1]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

获取conv1中的变量,然后将权重和偏置项分别取出,注意权重的顺序。

如果想用conv1 的权重和偏置来初始化自己网络中的conv1层,则可用下面的方法:

with tf.Session() as sess:
    with tf.variable_scope(layer_name, reuse=True):
        for subkey, data in zip('weights', 'biases'), pretrained_weight[layer_name]:
            sess.run(tf.get_variable(subkey).assign(data))
  • 1
  • 2
  • 3
  • 4

总结

以上是tensorflow中常见的预训练模型操作,在实际操作过程中根据自己的场景需求选择使用。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/189639
推荐阅读
相关标签
  

闽ICP备14008679号