赞
踩
bert、roberta、ernie在中文mlm任务上效果查看
- # -*- coding: utf-8 -*-
- import torch
- from transformers import BertTokenizer, BertForMaskedLM
-
-
- def get_mlm_model(list_):
- ret = []
- for path in list_:
- tokenizer = BertTokenizer.from_pretrained(path)
- model = BertForMaskedLM.from_pretrained(path)
- ret.append((path, tokenizer, model))
- return ret
-
-
- def gen_text(input_tx, tokenizer, model):
- tokenized_text = tokenizer.tokenize(input_tx)
- indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
-
- tokens_tensor = torch.tensor([indexed_tokens])
- segments_tensors = torch.tensor([[0] * len(tokenized_text)])
-
- with torch.no_grad():
- outputs = model(tokens_tensor, token_type_ids=segments_tensors)
- predictions = outputs[0]
-
- predicted_index = [torch.argmax(predictions[0, i]).item() for i in range(0, (len(tokenized_text) - 1))]
- predicted_token = [tokenizer.convert_ids_to_tokens([predicted_index[x]])[0] for x in
- range(1, (len(tokenized_text) - 1))]
- predicted_token = ''.join(predicted_token)
- print('raw token is:', input_tx)
- print('Predicted token is:', predicted_token)
- return predicted_token
-
-
- if __name__ == '__main__':
- list_ = get_mlm_model([
- 'bert-base-chinese',
- 'nghuyong/ernie-1.0',
- 'hfl/chinese-roberta-wwm-ext',
- # 'voidful/albert_chinese_tiny', # albert有点问题,有些层没参数,使用的是初始化参数
- ])
- inputs = [
- "[CLS]清华大学[MASK][MASK]在哪里[SEP]",
- "[CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]",
- "[CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]",
- "[CLS]今天的股票会[MASK]吗[SEP]",
- "[CLS]今天的股票会[MASK][MASK][SEP]",
- ]
- for input_ in inputs:
- for name, tokenizer, model in list_:
- print(name)
- gen_text(input_, tokenizer, model)
- print()
结果
bert-base-chinese
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 。华大学校址在哪里
nghuyong/ernie-1.0
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 清华大学大华在哪里
hfl/chinese-roberta-wwm-ext
raw token is: [CLS]清华大学[MASK][MASK]在哪里[SEP]
Predicted token is: 清华大学究底在哪里
bert-base-chinese
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 《庸》是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
nghuyong/ernie-1.0
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游记是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
hfl/chinese-roberta-wwm-ext
raw token is: [CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游梦是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
bert-base-chinese
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 《庸》是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
nghuyong/ernie-1.0
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游记是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
hfl/chinese-roberta-wwm-ext
raw token is: [CLS][MASK][MASK][MASK]是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]
Predicted token is: 西游梦是中国神魔小说的经典之作,与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。
bert-base-chinese
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 。天的股票会跌?
nghuyong/ernie-1.0
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 今天的股票会涨吗
hfl/chinese-roberta-wwm-ext
raw token is: [CLS]今天的股票会[MASK]吗[SEP]
Predicted token is: 今天的股票会涨吗
bert-base-chinese
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 。天的。票会吗?
nghuyong/ernie-1.0
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 今天的股票会怎样
hfl/chinese-roberta-wwm-ext
raw token is: [CLS]今天的股票会[MASK][MASK][SEP]
Predicted token is: 今天的股票会涨吗
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。