当前位置:   article > 正文

试试在transformers中调用ERNIE

ernieformaskedlm

ERNIE是什么

a27e4910b147296d89bd962a30cf5660.jpeg
ERNIE发展路径

文心大模型ERNIE是百度发布的产业级知识增强大模型,涵盖了NLP大模型和跨模态大模型。在中文场景下,ERNIE有明显的优势,目前已经发布了一系列模型,包括ERNIE1.0, ERNIE2.0, ERNIE3.0, ERNIE-Gram, ERNIE-Doc等,并且一直在持续更新中。ERNIE官方的代码和模型是PaddlePaddle版本的,具体可以参见:PaddlePaddle/ERNIEPaddlePaddle/PaddleNLP 这两个repo。

在transformers中调用ERNIE

但是目前学术界和工业界在NLP大模型方面一般都会基于 huggingface/transformers 来开展工作。所以就会产生很强烈的动机将ERNIE从PaddlePaddle版本转换到Pytorch版本,这里的转换包括模型的代码和模型的参数。经过漫长的code review,本人提交的ERNIE模型最终合入了transformers中,并在4.22.0版本中可以直接体验了!!

250faa0135eb5818de55446fe7c5ea47.png
https://github.com/huggingface/transformers/releases/tag/v4.22.0

快速开始

首先将`transformers`升级到4.22.0版本及以上

pip install --upgrade transformers

仅需3行代码即可快速调用ERNIE模型(以ernie1.0为例):

  1. from transformers import BertTokenizer, ErnieModel
  2. tokenizer = BertTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
  3. model = ErnieModel.from_pretrained("nghuyong/ernie-1.0-base-zh")

目前已经支持模型,包括ERNIE1.0,ERNIE2.0,ERNIE3.0,ERNIE-gram,ERNIE-health在内的10个模型:

c168d4a00d87d56f6731bbd0c13a7cc0.jpeg
https://huggingface.co/nghuyong

输出验证

为了快速检验我们模型转换的是否正确,我们可以将官方PaddlePaddle版本的结果与我们转换后Pytorch版本的结果做一个对比验证。均对`welcome to ernie pytorch project`进行编码,并打印最后的pooled层

  1. import paddle
  2. import transformers
  3. # huggingface/transformers
  4. tokenizer = transformers.BertTokenizer.from_pretrained('nghuyong/ernie-1.0-base-zh')
  5. model = transformers.ErnieModel.from_pretrained('nghuyong/ernie-1.0-base-zh')
  6. input_ids = torch.tensor([tokenizer.encode(text="welcome to ernie pytorch project", add_special_tokens=True)])
  7. model.eval()
  8. with torch.no_grad():
  9. pooled_output = model(input_ids)[0]
  10. print('huggingface result')
  11. print('pool output:', pooled_output.numpy())
  12. # paddlepaddle
  13. tokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained("ernie-1.0-base-zh")
  14. model = paddlenlp.transformers.AutoModel.from_pretrained("ernie-1.0-base-zh")
  15. inputs = tokenizer("welcome to ernie pytorch project")
  16. inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
  17. model.eval()
  18. with paddle.no_grad():
  19. pooled_output = model(**inputs)
  20. print('paddle result')
  21. print('pool output:', pooled_output[0].numpy())

下面是计算的结果

  1. huggingface result
  2. pool output: [-1. -1. 0.9981035 -0.9996652 -0.78173476 -1. -0.9994901 0.97012603 0.85954666 0.9854131 ]
  3. paddle result
  4. pool output: [-0.99999976 -0.99999976 0.9981028 -0.9996651 -0.7815545 -0.99999976 -0.9994898 0.97014064 0.8594844 0.985419 ]

可以看到转后前后两者计算的结果保持一致(精确度为0.0001)

论文复现

我们知道ERNIE1.0的特点是融入了实体的信息。现在可以基于转换后的代码和模型,快速验证论文中的例子。

  1. import transformers
  2. tokenizer = transformers.BertTokenizer.from_pretrained('nghuyong/ernie-1.0-base-zh')
  3. model = transformers.ErnieForMaskedLM.from_pretrained('nghuyong/ernie-1.0-base-zh')
  4. input_ids = torch.tensor([tokenizer.encode(text="[MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。",
  5. add_special_tokens=True)])
  6. model.eval()
  7. with torch.no_grad():
  8. predictions = model(input_ids)[0][0]
  9. predicted_index = [torch.argmax(predictions[i]).item() for i in range(predictions.shape[0])]
  10. predicted_token = [tokenizer._convert_id_to_token(predicted_index[i]) for i in
  11. range(1, (predictions.shape[0] - 1))]
  12. print('predict result:\t', predicted_token)

预测的结果为:

predict result:	 ['西', '游', '记', '是', '中', '国', '神', '魔', '小', '说', '的', '经', '典', '之', '作', ',', '与', '《', '三', '国', '演', '义', '》', '《', '水', '浒', '传', '》', '《', '红', '楼', '梦', '》', '并', '称', '为', '中', '国', '古', '典', '四', '大', '名', '著', '。']

进一步对比其他的中文预训练模型:

  1. input:
  2. [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
  3. output:
  4. {
  5. "bert-base-chinese": "《 神 》",
  6. "hfl/chinese-bert-wwm": "天 神 奇",
  7. "nghuyong/ernie-1.0-base-zh": "西 游 记"
  8. }

可以看到ERNIE模型能准确预测出「西游记」,在实体理解场景下ERNIE确实具备显著的优势!


最后放一下模型转换和测试的代码,欢迎star :)

GitHub - nghuyong/ERNIE-Pytorch: ERNIE Pytorch Version

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/352080
推荐阅读
相关标签
  

闽ICP备14008679号