当前位置:   article > 正文

基于ModelScope获取[MASK]位置的top k个候选token_modelscope token 是什么

modelscope token 是什么

起因是我的任务需求涉及到使用PoNet这个模型来进行完形填空任务,需要根据logits值获取填入[MASK]位置的词汇表中的top k个候选token

我本身只用过huggingface,按huggingface的方式掩码语言模型一般都会提供有xxxxForMaskedLM这样一个接口,我可以方便的获取词汇表中每个token被填入的logit值

但是PoNet这个模型比较特殊,他在huggingface上没有提供xxxxForMaskedLM,只提供了PoNetForSequenceClassification接口

于是我来到ModelScope,ModelScope虽然有一个fill-mask的pipeline,但是这个pipeline封装得太好了,这个pipeline返回的就是[MASK]位置被填入logit值最高的token后的字符串,我无法直接获取到每个token的logit,我看官方文档也没有提到有没有类似xxxxForMaskedLM的接口,无奈只能自己摸索。

几个小时后终于给我摸索出来了

其实ModelScope也是有类似xxxxForMaskedLM的接口的,只不过藏得比较深,而且官网上的文档没有提到

  1. from modelscope.models.nlp.task_models import ModelForFillMask
  2. from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer

 这时候只要像用huggingface那样用ModelScope就可以了

  1. plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
  2. plm.cuda()
  3. tokenizer = PoNetTokenizer.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)

但是这么写又有问题了

我寻思我不是用ModelScope加载吗?为啥下载tokenizer的时候链接变成了从huggingface下载???

 于是我就把tokenizer名称换成在huggingface上的名称

  1. plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
  2. plm.cuda()
  3. tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)

这样就成功了

完整代码:

  1. from modelscope.models.nlp.task_models import ModelForFillMask
  2. from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer
  3. import torch
  4. plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
  5. plm.cuda()
  6. tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)
  7. inputs = tokenizer("I want a cup of " + tokenizer.mask_token + ", please.", return_tensors="pt").to("cuda")
  8. mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0].item()
  9. logits = plm(**inputs).logits
  10. mask_logits = logits[0][mask_token_index]
  11. topk_ids = torch.topk(mask_logits, k=5, sorted=True, largest=True)
  12. topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids.indices)
  13. print(topk_tokens)

运行结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/189644
推荐阅读
相关标签
  

闽ICP备14008679号