
A pitfall to avoid when using BertForMaskedLM with BERT-family language models


When using BertForMaskedLM from transformers to predict masked tokens, you must wrap the input with the special tokens [CLS] and [SEP]. Otherwise the predictions are very poor!

from transformers import AlbertTokenizer, AlbertForMaskedLM
import torch

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2', cache_dir='E:/Projects/albert/')
model = AlbertForMaskedLM.from_pretrained('E:/Projects/albert')

sentence = "It is a very beautiful book."
tokens = ['[CLS]'] + tokenizer.tokenize(sentence) + ['[SEP]']

# i is the index of the token being masked (skip [CLS] at 0 and [SEP] at the end)
for i in range(1, len(tokens)-1):
    tmp = tokens[:i] + ['[MASK]'] + tokens[i+1:]
    masked_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tmp)])
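    # single-sentence input, so all segment (token type) ids are 0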
    segment_ids = torch.tensor([[0]*len(tmp)])

    outputs = model(masked_ids, token_type_ids=segment_ids)
    prediction_scores = outputs[0]
    print(tmp)
    # print the predicted token at the masked position
    prediction_index = torch.argmax(prediction_scores[0, i]).item()
    predicted_token = tokenizer.convert_ids_to_tokens([prediction_index])[0]
    print(predicted_token)
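
For comparison, here is a minimal sketch that sidesteps the pitfall entirely by letting the tokenizer insert [CLS] and [SEP] itself. It reuses the checkpoint name and local paths from the snippet above (those paths are the original author's setup), and it assumes a transformers version new enough to support calling the tokenizer directly (v3+):

from transformers import AlbertTokenizer, AlbertForMaskedLM
import torch

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2', cache_dir='E:/Projects/albert/')
model = AlbertForMaskedLM.from_pretrained('E:/Projects/albert')

# tokenizer(...) adds [CLS] and [SEP] automatically (add_special_tokens=True by default)
inputs = tokenizer("It is a very [MASK] book.", return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)
prediction_scores = outputs[0]

# find the [MASK] position and print the highest-scoring token
mask_index = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()
predicted_id = torch.argmax(prediction_scores[0, mask_index]).item()
print(tokenizer.convert_ids_to_tokens([predicted_id])[0])

Either way, the key point is the same: the model only behaves as expected when [CLS] and [SEP] are present in the input.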