赞
踩
起因是我的任务需求涉及到使用PoNet这个模型来进行完形填空任务,需要根据logits值获取填入[MASK]位置的词汇表中的top k个候选token
我本身只用过huggingface,按huggingface的方式掩码语言模型一般都会提供有xxxxForMaskedLM这样一个接口,我可以方便的获取词汇表中每个token被填入的logit值
但是PoNet这个模型比较特殊,他在huggingface上没有提供xxxxForMaskedLM,只提供了PoNetForSequenceClassification接口
于是我来到ModelScope,ModelScope虽然有一个fill-mask的pipeline,但是这个pipeline封装得太好了,这个pipeline返回的就是[MASK]位置被填入logit值最高的token后的字符串,我无法直接获取到每个token的logit,我看官方文档也没有提到有没有类似xxxxForMaskedLM的接口,无奈只能自己摸索。
几个小时后终于给我摸索出来了
其实ModelScope也是有类似xxxxForMaskedLM的接口的,只不过藏得比较深,而且官网上的文档没有提到
- from modelscope.models.nlp.task_models import ModelForFillMask
- from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer
这时候只要像用huggingface那样用ModelScope就可以了
- plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
- plm.cuda()
- tokenizer = PoNetTokenizer.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
但是这么写又有问题了
我寻思我不是用ModelScope加载吗?为啥下载tokenizer的时候链接变成了从huggingface下载???
于是我就把tokenizer名称换成在huggingface上的名称
- plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
- plm.cuda()
- tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)
这样就成功了
完整代码:
- from modelscope.models.nlp.task_models import ModelForFillMask
- from modelscope.models.nlp.ponet.tokenization import PoNetTokenizer
- import torch
-
- plm = ModelForFillMask.from_pretrained("damo/nlp_ponet_fill-mask_english-base", trust_remote_code=True)
- plm.cuda()
- tokenizer = PoNetTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)
-
- inputs = tokenizer("I want a cup of " + tokenizer.mask_token + ", please.", return_tensors="pt").to("cuda")
- mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0].item()
-
- logits = plm(**inputs).logits
- mask_logits = logits[0][mask_token_index]
- topk_ids = torch.topk(mask_logits, k=5, sorted=True, largest=True)
- topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids.indices)
- print(topk_tokens)
运行结果:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。