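Recently I looked into how to use a pretrained BERT from tensorflow-hub. At first the official tutorial on using the pretrained models would not open from within China, so I read a lot of blog posts and ran into many pitfalls. Only when I finally found a tutorial I could open did I realize how simple it is to use.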
Experiment versions:
tensorflow version: 2.3.0
tensorflow-hub version: 0.9.0
python version: 3.7.6
Data preparation:
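First, anyone familiar with BERT knows that it takes three inputs: inputIds, inputMask, and segmentIds. I won't go into detail here, since explanations are easy to find online.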
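For readers who have not built these three inputs before, here is a minimal sketch of how they are commonly assembled from already-tokenized ids. The helper name `build_bert_inputs` is hypothetical, and the special-token ids (101 for [CLS], 102 for [SEP], 0 for [PAD]) are assumptions based on the standard BERT vocabulary files; check them against the vocab.txt of the model you actually use.

```python
from typing import List, Optional, Tuple


def build_bert_inputs(
    tokens_a: List[int],
    tokens_b: Optional[List[int]] = None,
    max_seq_length: int = 256,
    cls_id: int = 101,  # assumed id of [CLS]; verify in your vocab.txt
    sep_id: int = 102,  # assumed id of [SEP]; verify in your vocab.txt
    pad_id: int = 0,    # assumed id of [PAD]
) -> Tuple[List[int], List[int], List[int]]:
    """Return (input_ids, input_mask, segment_ids), each of length max_seq_length."""
    tokens_a = list(tokens_a)
    tokens_b = list(tokens_b) if tokens_b else []

    # Reserve room for [CLS] plus one [SEP] per segment.
    budget = max_seq_length - (3 if tokens_b else 2)
    # Trim the longer segment one token at a time until everything fits.
    while len(tokens_a) + len(tokens_b) > budget:
        longer = tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b
        longer.pop()

    # Segment A: [CLS] tokens_a [SEP], all with segment id 0.
    input_ids = [cls_id] + tokens_a + [sep_id]
    segment_ids = [0] * len(input_ids)
    # Optional segment B: tokens_b [SEP], all with segment id 1.
    if tokens_b:
        input_ids += tokens_b + [sep_id]
        segment_ids += [1] * (len(tokens_b) + 1)

    # The mask is 1 for real tokens and 0 for padding.
    input_mask = [1] * len(input_ids)
    padding = max_seq_length - len(input_ids)
    input_ids += [pad_id] * padding
    input_mask += [0] * padding
    segment_ids += [0] * padding
    return input_ids, input_mask, segment_ids


ids, mask, seg = build_bert_inputs([5, 6, 7], [8, 9], max_seq_length=10)
print(ids)   # [101, 5, 6, 7, 102, 8, 9, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
print(seg)   # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
```

Stacking the results for several sentences into int32 arrays of shape [batch, max_seq_length] gives inputs of the form the model code in this post expects.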
Code to get the BERT outputs directly:
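- import tensorflow as tf
- import tensorflow_hub as hub
-
- # Assumption: BERT_URL was not defined in the original snippet; this is the
- # model page referenced later in this post.
- BERT_URL = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/2"
- max_seq_length = 256
-
- input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,),
-                                        dtype=tf.int32, name="input_word_ids")
- input_mask = tf.keras.layers.Input(shape=(max_seq_length,),
-                                    dtype=tf.int32, name="input_mask")
- segment_ids = tf.keras.layers.Input(shape=(max_seq_length,),
-                                     dtype=tf.int32, name="segment_ids")
-
- # Set trainable to False so the BERT weights stay frozen
- module = hub.KerasLayer(BERT_URL, trainable=False)
- # The layer expects its inputs in this order: word ids, mask, segment ids
- pooled_output, sequence_output = module([input_word_ids, input_mask, segment_ids])
-
- # Build the model from its inputs and outputs
- model = tf.keras.Model(inputs=[input_word_ids, input_mask, segment_ids],
-                        outputs=[pooled_output, sequence_output])
-
- # Get the outputs; inputIds, inputMask, segmentIds are int32 arrays of shape [batch, 256]
- output = model.predict([inputIds, inputMask, segmentIds])
- # output: pooled_output has shape [batch, 768]; sequence_output has shape [batch, 256, 768]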
-------------------------------------------------BUG----------------------------------------------
I also tried the approach from the blog post in reference link 3 to get the BERT outputs, but ran into the following problem:
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (2 total):
* False
* None:
- # Experiment 1: argument names taken from https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/2
- outputs,_ = hub_module(input_word_ids=tf.constant(tmp_inputids),
- input_mask=tf.constant(tmp_inputMask),
- segment_ids=tf.constant(tmp_segmentIds))
- ---------------------------------------------------------------------------
- ValueError Traceback (most recent call last)
- <ipython-input-45-b1533b83b191> in <module>
- 2 outputs,_ = hub_module(input_word_ids=tf.constant(tmp_inputids),
- 3 input_mask=tf.constant(tmp_inputMask),
- ----> 4 segment_ids=tf.constant(tmp_segmentIds))
- 5
- 6 # # Experiment 2: argument names taken from the error message
-
- /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
- 507
- 508 def _call_attribute(instance, *args, **kwargs):
- --> 509 return instance.__call__(*args, **kwargs)
- 510
- 511
-
- /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
- 778 else:
- 779 compiler = "nonXla"
- --> 780 result = self._call(*args, **kwds)
- 781
- 782 new_tracing_count = self._get_tracing_count()
-
- /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
- 812 # In this case we have not created variables on the first call. So we can
- 813 # run the first trace but we should fail if variables are created.
- --> 814 results = self._stateful_fn(*args, **kwds)
- 815 if self._created_variables:
- 816 raise ValueError("Creating variables on a non-first call to a function"
-
- /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager