赞
踩
Tensorflow于1.7之后推出了tensorflow hub,其是一个适合于迁移学习的部分,主要通过将tensorflow的训练好的模型进行模块划分,并可以再次加以利用。不过介于推出不久,目前只有图像的分类和文本的分类以及少量其他模型
这里先通过几个简单的例子,来展示该hub的使用流程。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tensorflow as tf import tensorflow_hub as hub def half_plus_two(): '''该函数主要是创建一个简单的模型,其网络结构就是y = a*x + b ''' # 创建两个变量,a和b,如网络的权重和偏置 a = tf.get_variable('a', shape=[]) b = tf.get_variable('b', shape=[]) # 创建一个占位变量,为后面graph的输入提供准备 x = tf.placeholder(tf.float32) # 创建一个完整的graph y = a*x + b # 通过hub的add_signature,建立hub需要的网络 hub.add_signature('function1',inputs=x, outputs=y) y = a * x hub.add_signature('function2', inputs=x, outputs=y) def export_module(path): '''该函数用于调用创建api进行module创建,然后进行网络的权重赋值,最后通过session进行运行权重初始化,并最后输出该module''' # 通过hub的create_module_spec,接收函数建立一个Module spec = hub.create_module_spec(half_plus_two) # 防止串graph,将当期的操作放入同一个graph中 with tf.Graph().as_default(): # 通过hub的Module读取一个模块,该模块可以是url链接,表示从tensorflow hub去拉取, # 或者接收上述创建好的module module = hub.Module(spec) # 这里演示如何将权重值赋予到graph中的变量,如从checkpoint中进行变量恢复等 init_a = tf.assign(module.variable_map['a'], 0.5) init_b = tf.assign(module.variable_map['b'], 2.0) init_vars = tf.group([init_a, init_b]) with tf.Session() as sess: # 运行初始化,为了将其中变量的值设置为赋予的值 sess.run(init_vars) # 将模型导出到指定路径 module.export(path,sess) if __name__ == '__main__': export_module("./module")
运行上述代码,可得
可以看出,该例子中,生成一个Module是
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tensorflow as tf import tensorflow_hub as hub def testExportTool1(self): # 指定module的文件夹位置,这里是export module_path = os.path.join('.','module') with tf.Graph().as_default(): # 读取当前存在的一个module m = hub.Module(module_path) print('signature',m.get_signature_names()) # 如直接采用y=f(x) 一样进行调用, output1= m([10,3,4], signature='function1', as_dict=True) output2 = m([10, 3, 4], signature='function2') with tf.Session() as sess: # 惯例进行全局变量初始化 sess.run(tf.initializers.global_variables()) # 观察生成的值是否与预定义值一致,即prediction是否与label一致 print(sess.run(output1)['default']) print(sess.run(output2)) self.assertAllEqual(sess.run(output1)['default'], [7, 3.5, 4]) self.assertAllEqual(sess.run(output2), [5, 1.5, 2]) if __name__ == '__main__': testExportTool1()
对于调用来说,就十分简单了
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。