赞
踩
本文主要介绍BERT中BertForQuestionAnswering的简单使用,并附带一些有助于新手理解的网站
BERT模型下载地址
下载后解压,本文解压路径如下:
由于需要使用pytorch进行训练,所以需要将模型转换成pytorch可运行的形式。
转换的代码为:
import argparse import logging import torch from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert logging.basicConfig(level=logging.INFO) def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): # Initialise PyTorch model config = BertConfig.from_json_file(bert_config_file) print("Building PyTorch model from configuration: {}".format(str(config))) model = BertForPreTraining(config) # Load weights from tf checkpoint load_tf_weights_in_bert(model, config, tf_checkpoint_path) # Save pytorch-model print("Save PyTorch model to {}".format(pytorch_dump_path)) torch.save(model.state_dict(), pytorch_dump_path) if __name__ == "__main__": convert_tf_checkpoint_to_pytorch(r'E:\paperCode\BERT\uncased_L-12_H-768_A-12\bert_model.ckpt', r'E:\paperCode\BERT\uncased_L-12_H-768_A-12\bert_config.json', r'E:\paperCode\BERT\uncased_L-12_H-768_A-12\pytorch_model.bin')
模型转换成功后,将文件放入一个新的文件夹,config.json是Transformer要求的文件名,由bert_config重命名得到。
具体的问答代码如下:
import torch import os # 使用 transformers的模型库需要将bert模型目录下的 bert_config.json 改为 config.json from transformers import BertTokenizer,BertForQuestionAnswering # 这个要从 transformers 导入 bert_path = r"E:\paperCode\BERT\qa" ## 上面我们自定义的保留最后3个模型的路径 model = BertForQuestionAnswering.from_pretrained(bert_path) tokenizer = BertTokenizer.from_pretrained(bert_path) question, doc = "Who is Lyon" , "Lyon is a killer" encoding = tokenizer.encode_plus(text = question,text_pair = doc, verbose=False) inputs = encoding['input_ids'] #Token embeddings sentence_embedding = encoding['token_type_ids'] #Segment embeddings tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens print("tokens: ", tokens) outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding])) #BertForQuestionAnswering返回一个QuestionAnsweringModelOutput对象。 #由于将BertForQuestionAnswering的输出设置为start_scores, end_scores, # 因此返回的QuestionAnsweringModelOutput对象被强制转换为字符串的元组('start_logits', 'end_logits'),从而导致类型不匹配错误。 # torch.argmax(dim)会返回dim维度上张量最大值的索引 start_index = torch.argmax(outputs.start_logits) end_index = torch.argmax(outputs.end_logits) # start_index = torch.argmax(start_scores) # end_index = torch.argmax(end_scores) #print("start_index:%d, end_index %d"%(start_index, end_index)) answer = ' '.join(tokens[start_index:end_index + 1]) print(answer) # 每次执行的结果不一致,这里因为模型没有经过训练,所以效果不好,输出结果不佳
注意点:config.json的重命名、start_index和end_index的格式问题。
huggingface
Hugging Face专注于NLP技术,拥有大型的开源社区。尤其是在github上开源的自然语言处理,预训练模型库 Transformers,已被下载超过一百万次,github上超过24000个star。Transformers 提供了NLP领域大量state-of-art的 预训练语言模型结构的模型和调用框架。以下是repo的链接(https://github.com/huggingface/transformers)
Transformer:
简介:
https://www.cnblogs.com/panchuangai/p/12567853.html
https://www.cnblogs.com/panchuangai/p/12567851.html
BERT介绍
https://zhuanlan.zhihu.com/p/113639892
本文在下文的基础上进行实验,对文中一些在现版本下不可运行的部分进行调整。
https://blog.csdn.net/weixin_46425692/article/details/109184688
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。