当前位置:   article > 正文

tensorflow——模型迁移_模型迁移 节点优化 节点合并 tensorflow

模型迁移 节点优化 节点合并 tensorflow

tensorflow的三种Graph结构:

  • Graph:tensorflow运行会话是在默认的Graph中,包含了各个运算节点和用于计算的张量;
  • GraphDef:将Graph序列化为python代码得到的图,可以理解为一种数据结构,以常量的形式保存Tensor,无法继续训练; ——对应pb文件
  • MetaGraph:将Graph进行序列化,进行模型保存,Tensor以变量形式保存,可以被继续训练(

    通过tf.export_meta_graph()保存Graph,得到MetaGraph

    通过import_meta_graph将MetaGraph恢复

    );           

对图的保存和恢复三对API,

  • tf.train.Saver()/saver.restore()
  • export_meta_graph/Import_meta_graph
  • tf.train.write_graph()/tf.Import_graph_def()

pb文件的导入

步骤1 构造 GraphDef 协议缓冲区:

 使用tf.Graph.as_graph_def 创建一个 GraphDef 原型

步骤3 将GraphDef 的内容导入默认图Graph中:

函数:tf.import_graph_def( graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None )

参数:

  • graph_def:一个 GraphDef 原型,它包含了要导入默认图形的操作.
  • input_map:将 graph_def 中的输入名称 (作为字符串) 映射到张量对象的字典.导入的图形中的指定输入张量的值将 re-mapped 到各自的幅值.
  • return_elements:包含将作为操作对象返回的 graph_def 中的操作名称的字符串列表或张量名称在 graph_def,将作为张量对象返回.
  • name:(可选)将在 graph_def 中预先处理的前缀.请注意,这不适用于导入的函数名称.默认为 "import".
  • op_dict:(可选)将 op 类型名称映射到 OpDefprotos 的字典.必须包含在 graph_def 中命名的每个 op 类型的 OpDef 原型.如果省略,请使用在全局注册表中注册的 OpDef 原型.
  • producer_op_list:(可选)由图的生产者使用的 (可能被剥掉的) OpDefs 的 OpList 原始列表.如果提供,attrs 的 ops 在 graph_def 而不在 op_dict,根据producer_op_list 的默认值 将被删除.这将允许后面的二进制文件所产生的一些 GraphDefs 被早期的二进制文件接受.

返回值:

  • 该函数返回来自导入图形的操作或张量对象的列表,对应于 return_elements 中的名称

代码:

  1. import tensorflow as tf
  2. # gfile模块定义在tensorflow/python/platform/gfile.py
  3. # 包含GFile、FastGFile和Open三个没有线程锁定的文件I/O包装器类
  4. from tensorflow.python.platform import gfile
  5. with tf.Session() as sess:
  6. # 使用FsatGFile类的构造函数返回一个FastGFile类
  7. with gfile.FastGFile("/home/jiangziyang/model/model.pb", 'rb') as f:
  8. graph_def = tf.GraphDef()
  9. # 使用FastGFile类的read()函数读取保存的模型文件,并以字符串形式
  10. # 返回文件的内容,之后通过ParseFromString()函数解析文件的内容
  11. graph_def.ParseFromString(f.read())
  12. # 使用import_graph_def()函数将graph_def中保存的计算图加载到当前图中
  13. # 原型import_graph_def(graph_def,input_map,return_elements,name,op_dict,
  14. # producer_op_list)
  15. result = tf.import_graph_def(graph_def, return_elements=["add:0"])
  16. print(sess.run(result))
  17. # 输出为[array([3.], dtype=float32)]

获取变量的方式:https://blog.csdn.net/zw__chen/article/details/82187324

pb文件的生成

步骤1 获取当前默认的计算图——get_default_graph()

步骤2 将graph图序列化为GraphDef图——as_graph_def()

步骤3 将变量替换为常量

convert_variables_to_constants(sess,input_graph_def,output_node_names,variable_names_whitelist=None,variable_names_blacklist=None)
参数:
  sess:会话
  input_graph_def: GraphDef图
  output_node_names: 要保存的计算图中计算结点的名称组成的字符串列表.
  variable_names_whitelist: 要转换为常量的变量名称的集合,默认是所有变量.
  variable_names_blacklist: 要省略转换为常量的变量列表.

返回:
  转换为常量的GraphDef图.

代码:

  1. import tensorflow as tf
  2. #graph_util模块定义在tensorflow/python/framework/graph_util.py
  3. from tensorflow.python.framework import graph_util
  4. a = tf.Variable(tf.constant(1.0, shape=[1]), name="a")
  5. b = tf.Variable(tf.constant(2.0, shape=[1]), name="b")
  6. result = a + b
  7. init_op = tf.global_variables_initializer()
  8. with tf.Session() as sess:
  9. sess.run(init_op)
  10. # 导出主要记录了TensorFlow计算图上节点信息的GraphDef部分
  11. # 使用get_default_graph()函数获取默认的计算图
  12. graph_def = tf.get_default_graph().as_graph_def()
  13. # convert_variables_to_constants()函数表示用相同值的常量替换计算图中所有变量,
  14. # 原型convert_variables_to_constants(sess,input_graph_def,output_node_names,
  15. # variable_names_whitelist, variable_names_blacklist)
  16. # 其中sess是会话,input_graph_def是具有节点的GraphDef对象,output_node_names
  17. # 是要保存的计算图中的计算节点的名称,通常为字符串列表的形式,variable_names_whitelist
  18. # 是要转换为常量的变量名称集合(默认情况下,所有变量都将被转换),
  19. # variable_names_blacklist是要省略转换为常量的变量名的集合。
  20. output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
  21. # 将导出的模型存入.pb文件
  22. with tf.gfile.GFile("/home/jiangziyang/model/model.pb", "wb") as f:
  23. # SerializeToString()函数用于将获取到的数据取出存到一个string对象中,
  24. # 然后再以二进制流的方式将其写入到磁盘文件中
  25. f.write(output_graph_def.SerializeToString())

 

 
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/911695
推荐阅读
相关标签
  

闽ICP备14008679号