当前位置:   article > 正文

基于tensorflow-hub使用预训练bert模型——简单易上手成功率百分百_hubbert模型

hubbert模型

最近,研究了下如何使用基于tensorflow-hub中预训练bert,一开始找到的关于预模型使用介绍的官方教程国内打不开,所以看了很多博客遇到了很多坑,直至最后找到能打开的教程,才发现使用很简单。

实验版本:

         tensorflow版本: 2.3.0

         tensorflow-hub版本:0.9.0

         python版本: 3.7.6

数据准备:

         首先,熟悉bert的都知道输入有3个:inputIds、inputMask、segmentIds,这个不多说了,百度一大堆。

直接获取bert输出代码:

  1. max_seq_length = 256
  2. input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,),
  3. dtype=tf.int32,name="input_word_ids")
  4. input_mask = tf.keras.layers.Input(shape=(max_seq_length,),
  5. dtype=tf.int32,name="input_mask")
  6. segment_ids = tf.keras.layers.Input(shape=(max_seq_length,),
  7. dtype=tf.int32,name="segment_ids")
  8. # 将trainable值改为False
  9. module = hub.KerasLayer(BERT_URL,trainable=False)#,signature="token")
  10. pooled_output, sequence_output = module([input_mask,segment_ids,input_word_ids])
  11. # 构建模型输入输出
  12. model = tf.keras.Model(inputs=[input_word_ids,input_mask,segment_ids],outputs=[pooled_output,sequence_output])
  13. # 获取输出
  14. output = model.predict([inputIds,inputMask,segmentIds])
  15. # output输出结果 ----》 pool_out: shape=[batch, 768];sequence_out: shape=[batch, 256, 768]

-------------------------------------------------BUG----------------------------------------------

这里也尝试了参考链接3中博客方式获取bert输出结果,但是遇到个问题

         ValueError: Could not find matching function to call loaded from the SavedModel. Got:
                           Positional arguments (2 total):
                         * False
                         * None:

  1. # 实验内容1——参数名来自https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/2
  2. outputs,_ = hub_module(input_word_ids=tf.constant(tmp_inputids),
  3. input_mask=tf.constant(tmp_inputMask),
  4. segment_ids=tf.constant(tmp_segmentIds))
  1. ---------------------------------------------------------------------------
  2. ValueError Traceback (most recent call last)
  3. <ipython-input-45-b1533b83b191> in <module>
  4. 2 outputs,_ = hub_module(input_word_ids=tf.constant(tmp_inputids),
  5. 3 input_mask=tf.constant(tmp_inputMask),
  6. ----> 4 segment_ids=tf.constant(tmp_segmentIds))
  7. 5
  8. 6 # # 实验内容2——参数名来自报错提示
  9. /opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
  10. 507
  11. 508 def _call_attribute(instance, *args, **kwargs):
  12. --> 509 return instance.__call__(*args, **kwargs)
  13. 510
  14. 511
  15. /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
  16. 778 else:
  17. 779 compiler = "nonXla"
  18. --> 780 result = self._call(*args, **kwds)
  19. 781
  20. 782 new_tracing_count = self._get_tracing_count()
  21. /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
  22. 812 # In this case we have not created variables on the first call. So we can
  23. 813 # run the first trace but we should fail if variables are created.
  24. --> 814 results = self._stateful_fn(*args, **kwds)
  25. 815 if self._created_variables:
  26. 816 raise ValueError("Creating variables on a non-first call to a function"
  27. /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/532005
推荐阅读
相关标签
  

闽ICP备14008679号