赞
踩
tensorflow的三种Graph结构:
通过tf.export_meta_graph()保存Graph,得到MetaGraph
通过import_meta_graph将MetaGraph恢复
);对图的保存和恢复三对API,
tf.train.Saver()/saver.restore()
export_meta_graph/Import_meta_graph
tf.train.write_graph()/tf.Import_graph_def()
步骤1 构造 GraphDef 协议缓冲区:
使用tf.Graph.as_graph_def 创建一个 GraphDef 原型
步骤3 将GraphDef 的内容导入默认图Graph中:
函数:tf.import_graph_def( graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None )
参数:
返回值:
代码:
- import tensorflow as tf
- # gfile模块定义在tensorflow/python/platform/gfile.py
- # 包含GFile、FastGFile和Open三个没有线程锁定的文件I/O包装器类
- from tensorflow.python.platform import gfile
-
- with tf.Session() as sess:
- # 使用FsatGFile类的构造函数返回一个FastGFile类
- with gfile.FastGFile("/home/jiangziyang/model/model.pb", 'rb') as f:
- graph_def = tf.GraphDef()
- # 使用FastGFile类的read()函数读取保存的模型文件,并以字符串形式
- # 返回文件的内容,之后通过ParseFromString()函数解析文件的内容
- graph_def.ParseFromString(f.read())
-
- # 使用import_graph_def()函数将graph_def中保存的计算图加载到当前图中
- # 原型import_graph_def(graph_def,input_map,return_elements,name,op_dict,
- # producer_op_list)
- result = tf.import_graph_def(graph_def, return_elements=["add:0"])
-
- print(sess.run(result))
- # 输出为[array([3.], dtype=float32)]
获取变量的方式:https://blog.csdn.net/zw__chen/article/details/82187324
步骤1 获取当前默认的计算图——get_default_graph()
步骤2 将graph图序列化为GraphDef图——as_graph_def()
步骤3 将变量替换为常量
convert_variables_to_constants(sess,input_graph_def,output_node_names,variable_names_whitelist=None,variable_names_blacklist=None)
参数: sess:会话 input_graph_def: GraphDef图 output_node_names: 要保存的计算图中计算结点的名称组成的字符串列表. variable_names_whitelist: 要转换为常量的变量名称的集合,默认是所有变量. variable_names_blacklist: 要省略转换为常量的变量列表. 返回: 转换为常量的GraphDef图.
代码:
- import tensorflow as tf
- #graph_util模块定义在tensorflow/python/framework/graph_util.py
- from tensorflow.python.framework import graph_util
-
- a = tf.Variable(tf.constant(1.0, shape=[1]), name="a")
- b = tf.Variable(tf.constant(2.0, shape=[1]), name="b")
- result = a + b
- init_op = tf.global_variables_initializer()
-
- with tf.Session() as sess:
- sess.run(init_op)
-
- # 导出主要记录了TensorFlow计算图上节点信息的GraphDef部分
- # 使用get_default_graph()函数获取默认的计算图
- graph_def = tf.get_default_graph().as_graph_def()
-
- # convert_variables_to_constants()函数表示用相同值的常量替换计算图中所有变量,
- # 原型convert_variables_to_constants(sess,input_graph_def,output_node_names,
- # variable_names_whitelist, variable_names_blacklist)
- # 其中sess是会话,input_graph_def是具有节点的GraphDef对象,output_node_names
- # 是要保存的计算图中的计算节点的名称,通常为字符串列表的形式,variable_names_whitelist
- # 是要转换为常量的变量名称集合(默认情况下,所有变量都将被转换),
- # variable_names_blacklist是要省略转换为常量的变量名的集合。
- output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
-
- # 将导出的模型存入.pb文件
- with tf.gfile.GFile("/home/jiangziyang/model/model.pb", "wb") as f:
- # SerializeToString()函数用于将获取到的数据取出存到一个string对象中,
- # 然后再以二进制流的方式将其写入到磁盘文件中
- f.write(output_graph_def.SerializeToString())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。