当前位置:   article > 正文

手把手教利用Bert实现知识库问答(详细注释)_使用bert模型实现问答助手

使用bert模型实现问答助手

跟着DataFountain学的,加了一些代码注释,DataFountain有数据集,链接:个人工作平台https://work.datafountain.cn/forum?id=121&type=2&source=1

一、数据处理

1、数据分析

数据格式如下:这个题目的意思就是判断Question和Sentence是否匹配,如果匹配label就是1

  1. # 统一导入工具包
  2. import csv
  3. import transformers
  4. import torch
  5. import warnings
  6. warnings.filterwarnings('ignore')
  7. """
  8. 1.数据说明
  9. """
  10. # pd_table = pd.read_csv('./datasets/raw/WikiQA-train.tsv', encoding="utf-8", sep='\t')
  11. """
  12. 每条数据为以下6个字段:
  13. QuestionID: 问题id
  14. Question: 问题文本
  15. DocumentID: 检索到的作为答案来源的文档ID
  16. Document: 检索到的作为答案来源的文档标题
  17. SentenceID: 对于文档摘要中的每个句子的id
  18. Sentence: 文档中摘要中的句子
  19. label: 判断句子是否是答案的标签
  20. """
  21. """

2、数据加载处理

接下来加载有用的数据,将数据加载为<question, answer, label>这样的三元组,如果answer是question的正确答案,则lable为1,每个三元组用一个字典来存储。

  1. def load(filename):
  2. result = []
  3. with open(filename, mode='r', encoding="utf-8") as csvfile:
  4. spamreader = csv.reader(csvfile, delimiter='\t', quotechar='"') # 这里的spamreader装了csv中每一行的
  5. next(spamreader, None) # 这里自动迭代了一次,就跳过了第一行(第一行是标题),否则会读取到第一行
  6. for row in spamreader: # 对每一行进行遍历
  7. res = {} # 每一行的三元组用这个字典来存储
  8. res['question'] = str(row[1]) # 将Question赋值给字典的question
  9. res['answer'] = str(row[5]) # 将Sentence赋值给字典的answer
  10. res['label'] = int(row[6]) # 将label赋值给字典的label
  11. result.append(res) # 这里的res就是装了每一行需要用到的信息的字典,然后把这个字典放到result这个列表中,每一行对应的字典都添加给这个列表
  12. return result
  13. train_file = load('./datasets/raw/WikiQA-train.tsv')
  14. valid_file = load('./datasets/raw/WikiQA-dev.tsv')
  15. test_file = load('./datasets/raw/WikiQA-test.tsv')

train_file中的数据如下:

[{'question': 'how are glacier caves formed?', 'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .', 'label': 0},[......],[.....]....]

3、数据标准化

读入数据后,要分词、将自然语言转换为one-hot向量,将文本对齐或者截断为相同长度等。同时需要将数据处理为Bert需要的输入形式。

  1. # 3.1 max_length就是自己设定的文本最大长度,然后需要对小于max_length的文本长度进行补齐,对于长度不足的序列,在右边补padding
  2. def padding(squence, max_length, pad_token=0):
  3. # squence就是需要需要处理的文本,pad_token就是padding的token,默认为0
  4. # padding_length如果大于0,就代表需要补padding_length个pad_token,如果小于0,代表需要截断多少个token
  5. padding_length = max_length - len(squence)
  6. return squence + [pad_token] * padding_length
  7. # 3.2 Bert的标准输入,Bert的输入主要由input_ids, input_mask, token_type_ids三部分构成
  8. # input_ids:Bert的输入通常需要两个句子,句子A前由[CLS]开始,以[SEP]结束,后面再连接句子B
  9. # input_mask: 由于不同批次的数据长度不同,因此会对数据进行补全。但补全的信息对于网络是无用的,这个主要是输入的句子可能存在填0的操作,attention模块不需要把填0的无意义的信息算进来,所以使用mask操作。
  10. # token_type_ids:用于标记一个input_ids序列中哪些位置是第一句话,哪些位置是第二句话。

将文本转换为可运行的数据集:

  1. def tokenize(data, max_length, device):
  2. # 下面使用transformer中的BertTokenizer进行处理,tokenizer需要指定预训练模型位置,读取使用的词表vocab用于文本转换。
  3. model_path = './datasets/models/bert-pretrain'
  4. tokenizer = transformers.BertTokenizer.from_pretrained(model_path, do_lower_case=True)
  5. res = []
  6. for triple in data:
  7. # 这里将问题和答案两句话作为输入,使用[sep]将两句话链接, 同时转化为one-hot向量
  8. # tokenizer的encode_plus方法除了可以对文本进行分词外,还可以将输入序列转换为上面描述的标准形式
  9. inputs = tokenizer.encode_plus(
  10. triple['question'], # 输入bert的第一句话,在sep前面
  11. triple['answer'], # 输入bert的第二句话,在sep后面,注意这句话的结尾还有个sep
  12. add_special_tokens=True, # 设为True就可以将句子专为Bert对应的输入形式
  13. max_length=max_length, # 指定序列的最大长度,超过长度会截断
  14. truncation=True
  15. )
  16. # 注意上面的inputs已经将每个词根据词表转换为数字了
  17. input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
  18. # 这里的attention_mask和input_mask是一个意思,先初始化为每个位置都是1
  19. attention_mask = [1] * len(input_ids)
  20. # 下面进行长度补全或截断
  21. input_ids = padding(input_ids, max_length, pad_token=0)
  22. attention_mask = padding(attention_mask, max_length, pad_token=0)
  23. token_type_ids = padding(token_type_ids, max_length, pad_token=0)
  24. label = triple['label']
  25. res.append((input_ids, attention_mask, token_type_ids, label))
  26. # 上面的(input_ids, attention_mask, token_type_ids, label)元组就是Bert的输入了
  27. # 下面把所有数据转换为tensor形式,并且让代表单词的数字确定为int型
  28. all_input_ids = torch.tensor([x[0] for x in res], dtype=torch.int64, device=device)
  29. all_attention_mask = torch.tensor([x[1] for x in res], dtype=torch.int64, device=device)
  30. all_token_type_ids = torch.tensor([x[2] for x in res], dtype=torch.int64, device=device)
  31. all_labels = torch.tensor([x[3] for x in res], dtype=torch.int64, device=device)
  32. # 将Bert的输入用pytorch的工具打包
  33. return torch.utils.data.TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)

二、模型搭建

  1. import transformers
  2. import torch
  3. from transformers import BertPreTrainedModel, BertModel
  4. from torch import nn
  5. import warnings
  6. warnings.filterwarnings('ignore')
  7. device = torch.device('cuda:0')
  8. """
  9. 使用bert来计算两个语句之间的匹配度,将问题和答案作为一个序列输入后,最后全连接层输出的特征向量转换为1维的分类输出,来判断是否匹配
  10. """
  11. # config为预训练模型的参数
  12. config = transformers.BertConfig.from_pretrained('./datasets/models/bert-pretrain')
  13. # 可以尝试用requires_grad冻结bert部分的参数,在Fine-tunning时只训练全连接的参数,这样训练会快一些
  14. class BertQA(BertPreTrainedModel):
  15. def __init__(self, config):
  16. super(BertQA, self).__init__(config)
  17. self.num_labels = config.num_labels # 分类数
  18. self.bert = BertModel(config) # BertModel也是transformers库中的一个类
  19. # 冻结bert参数,只fine-tuning后面层的参数
  20. for p in self.parameters(): # 应该是因为继承了父类,所以直接self.parameters()就好了
  21. p.requires_grad = False
  22. # Linear是将CLS的输出的最后一个维度变成2,这是一个二分类问题
  23. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  24. self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
  25. self.init_weights() # 初始化全连接层的权重
  26. def forward(self, input_ids=None, attention_mask=None,
  27. token_type_ids=None, position_ids=None,
  28. head_mask=None, inputs_embeds=None, labels=None):
  29. # 将数据输入Bert模型,得到Bert的输出
  30. outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
  31. position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)
  32. # 注意bert的输出是个元组,是以下四个输出
  33. """
  34. last_hidden_state:shape是(batch_size, sequence_length, hidden_size),hidden_size=768,它是模型最后一层输出的隐藏状态
  35. pooler_output:shape是(batch_size, hidden_size),这是序列的第一个token(classification token)的最后一层的隐藏状态,它是由线性层和Tanh激活函数进一步处理的,这个输出不是对输入的语义内容的一个很好的总结,对于整个输入序列的隐藏状态序列的平均化或池化通常更好。
  36. hidden_states:这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
  37. attentions:这也是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值
  38. """
  39. # 通过全连接网络,将特征转换为一个二维向量,可以看做标签0和1的得分情况
  40. # output[0]就是last_hidden_state, 然后[:, 0, :]的意思是每个样本只取第一个位置,也就是CLS的输出,然后用squeeze在中间加一个维度支撑
  41. logits = self.qa_outputs(outputs[0][:, 0, :]).squeeze() # 这里只选取第一列是只把第一列当做输出
  42. # 选择得分大的标签作为预测值
  43. predicted = nn.functional.softmax(logits, dim=-1)
  44. if labels is not None:
  45. loss = self.loss_fn(predicted, labels)
  46. return loss, predicted
  47. else:
  48. return predicted
  49. #模型实例化
  50. def model_real():
  51. model = BertQA.from_pretrained('./datasets/models/bert-pretrain', config=config)
  52. model.to(device)
  53. return model

三、模型训练

  1. #Dataset
  2. train_dataset = pre.tokenize(train_file, max_length, device=device)
  3. train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
  4. # 使用filter选择模型中未冻结的层,就可以只训练全连接层,而冻结bert层
  5. optimizer = transformers.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, eps=adam_epsilon)
  6. train_loss = []
  7. for epoch in range(epoch):
  8. print("Training epoch {}".format(epoch+1))
  9. for step, batch in enumerate(train_dataloader):
  10. model.train()
  11. model.zero_grad()
  12. inputs = {
  13. 'input_ids': batch[0],
  14. 'attention_mask': batch[1],
  15. 'token_type_ids': batch[2],
  16. 'labels': batch[3]
  17. }
  18. outputs = model(**inputs) # 如果使用**前缀,多余的参数会被认为是字典
  19. loss, results = outputs
  20. loss.backward()
  21. optimizer.step()
  22. train_loss.append(loss.item())

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/356560
推荐阅读
相关标签
  

闽ICP备14008679号