赞
踩
Tensorflow读取预训练模型是模型训练中常见的操作,通常的应用的场景包括:
1)训练中断后需要重新开始,将保存之前的checkpoint(包括.data
.meta
.index
checkpoint
这四个文件),然后重新加载模型,从上次断点处继续训练或预测。实现方法如下:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints_path'))
# 如果checkpoint存在则加载断点之前的训练模型
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
.meta
文件来重构网络,并检查权重数据:# ref: http://blog.csdn.net/spylyt/article/details/71601174
import tensorflow as tf
def restore_from_meta(sess, graph, ckpt_path):
with graph.as_default():
ckpt = tf.train.get_checkpoint_state(ckpt_path)
if ckpt and ckpt.model_checkpoint_path:
print('Found checkpoint, try to restore...')
saver = tf.train.import_meta_graph(''.join([ckpt.model_checkpoint_path, '.meta']))
saver.restore(sess, ckpt.model_checkpoint_path)
#### Check
# 打印出网络中可训练的权重参数名
for var in tf.trainable_variables():
print(var)
# 根据变量名称读取所需变量权重,打印其形状
conv1_2_weights = graph.get_tensor_by_name('conv1_2/weights:0')
print(conv1_2_weights.shape)
if __name__ == '__main__':
ckpt_path = 'data\\VOC2007_images\\GAP_backup\\'
graph = tf.Graph()
sess = tf.Session(graph=graph)
restore_from_meta(sess, graph, ckpt_path)
2)用上述f.train.Saver().restore(sess, ckpt_path)
方法来加载模型会将所有权重全部加载,如果你希望加载指定几层权重(比如在做transfer learning的时候),可以通过下面方法实现:
# ref: http://blog.csdn.net/qq_25737169/article/details/78125061
sess = tf.Session()
var = tf.global_variables()
var_to_restore = [val for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)
这样,就值加载了conv1
和conv2
的权重,而其他层中权重则通过网络初始化的方式获得。
如果用tensorflowslim
来构建网络的话,操作会更加简单:
exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))
该操作是restore除了layer1
和layer2
之外的其他权重。
上述两种情形是自己搭建了网络结构图或者虽然自己没有网络结构图,但手头有对应的.meta
文件,可以重构网络。如果只是有该网络的预训练模型权重,希望借用该模型中的某些层的权重来初始化自己的网络。这时候面对的场景就不一样了。
一般的模型权重会保存成:
1)checkpoint权重(model_ckpt.data
)或者frozen graph(.pb
)格式的文件
model_ckpt.data
权重可通过tensorflow.python.pywrap_tensorflow.NewCheckpointReader
或tf.train.NewCheckpointReader
获取,二者的功能一样,下面以前者举例:import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model_ckpt.ckpt")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出权重tensor名字和值
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key))
如果你想用某些层的pretrained权重来初始化你自己的网络,可以在sess中通过下面的操作完成:
with tf.variable_scope('', reuse = True):
sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
.pb
权重图可通过下面的方法获取:import os
import tensorflow as tf
var1_name = 'pool_3/_reshape:0'
var2_name = 'data/content:0'
graph = tf.Graph()
with graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile('model_graph.pb', 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
var1_tensor, var2_tensor = tf.import_graph_def(od_graph_def, return_elements = [var1_name, var2_name])
得到变量1 pool_3/_reshape:0
和变量2 data/content:0
的值。
2)如果预训练模型是由caffe model转过来的,这时候可能的格式为numpy文件(.npy
),它按照字典的方式来存储各层的权重数据。下面举一个例子说明:
import numpy as np
import cPickle
layer_name = "conv1"
with open('model.npy') as f:
pretrained_weight = cPickle.load(f)
layer = pretrained_weight[layer_name]
weights = layer[0]
biases = layer[1]
获取conv1
中的变量,然后将权重和偏置项分别取出,注意权重的顺序。
如果想用conv1
的权重和偏置来初始化自己网络中的conv1
层,则可用下面的方法:
with tf.Session() as sess:
with tf.variable_scope(layer_name, reuse=True):
for subkey, data in zip('weights', 'biases'), pretrained_weight[layer_name]:
sess.run(tf.get_variable(subkey).assign(data))
以上是tensorflow中常见的预训练模型操作,在实际操作过程中根据自己的场景需求选择使用。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。