当前位置:   article > 正文

问答系统案例----基于Bert实现知识库问答_this is not expected if you are initializing bertm

this is not expected if you are initializing bertmodel from thte checkpoint

问答系统案例----基于Bert实现知识库问答

基于Transformers.Trainer实现

任务描述:

知识库问答也叫做知识图谱问答,模型结合知识图谱,对输入的问题进行推理和查询从而得到正确答案的一项综合性任务。知识图谱问答方法可分为两大类,一种是基于信息检索的方式,一种是基于语义解析的方式。信息检索的方式不需要生成中间结果,直接得到问题答案,十分简洁,但是对复杂问题的处理能力有限。语义解析的方式需要对输入的自然语言问题进行语义解析,再进行推理,具备解决复杂问题的能力。本教程选用信息检索的方式进行讨论。

1.数据集:

使用开放式问答数据集WikiQA.WikiQA使用Bing查询日志作为问题源,每个问题都链接到一个可能有答案的维基百科页面,页面的摘要部分提供了关于这个问题的重要信息,WikiQA使用其中的句子作为问题的候选答案。数据集中共包括3047个问题和29258个句子。

# 统一导入工具包
import pandas as pd
import csv
import transformers
import torch
from torch.utils.data.dataset import T_co
from transformers import BertPreTrainedModel, BertModel, BertTokenizer
from torch import nn
import numpy as np
import os
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import warnings

warnings.filterwarnings('ignore')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
2022-04-10 21:51:53.682138: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
  • 1

2.数据准备

WikiQA问答数据集可以用于问答系统的训练。数据集中存放着问题的文本,每个问题对应的知识库数据,以及对应的答案。因此我们要将数据读入内存中,将不同类型的数据分别处理,将数据结构化。整体步骤如下:
(1) 数据说明
(2) 数据加载
(3) 数据标准化:向量化、对齐、掩码

pd_table = pd.read_csv("data/wiki/WikiQASent-dev.txt", sep="\t", header=None)
  • 1

3.数据加载

本数据集的知识库就是通过问题检索到的文档摘要,而摘要中的每一句话都作为候选答案。因此我们可以将问答问题转化为两个句子之间的匹配问题。为了后续模型的训练,我们将数据加载为<question,answer,label>这样的三元组。如果answer是question的正确答案,则label为1,反之则为0.每一个三元组用一个字典来存储。

定义load函数。使用csv将文件读入,在csv.reader中指定’\t’作为分隔符(delimiter),将数据自动分割。依次遍历每一行,将数据按照上述数据结构加载

def load(filename):
    result = []
    with open(filename, 'r', encoding='utf-8') as csvfile:
        spamreader = pd.read_csv(filename, sep="\t", header=None)
        for i in range(len(spamreader)):
            row = spamreader.iloc[i]
            res = {}
            res['question'] = str(row[0])
            res['answer'] = str(row[1])
            res['label'] = int(row[2])
            if res['question'] == "" or res['answer'] == "" or res['label'] == None:
                continue
            result.append(res)
    return result
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
train_file = load('data/wiki/WikiQASent-train.txt')
valid_file = load('data/wiki/WikiQASent-dev.txt')
test_file = load('data/wiki/WikiQASent-test.txt')
  • 1
  • 2
  • 3
# valid_file
  • 1
len(train_file), len(valid_file), len(test_file)
  • 1
(20360, 2733, 6165)
  • 1

4.数据token化并获取数据

接下来需要将数据处理为Bert的标准输入形式。Bert的输入主要由input_ids,attention_mask,token_type_ids三部分构成。

tokenize = BertTokenizer.from_pretrained("bert-base-uncased")
  • 1


def get_data(name_file):
    input_ids = []
    attention_mask = []
    token_type_ids = []
    labels = []

    for i, dic in enumerate(name_file):
        question = dic["question"]
        answer = dic["answer"]
        label = dic["label"]
        # output = tokenize.encode_plus(text=question, text_pair=answer, max_length=100, truncation=True, )
        # print(tokenize.decode(output["input_ids"]))
        output = tokenize.encode_plus(text=question, text_pair=answer, max_length=64, truncation=True,
                                      add_special_tokens=True, padding="max_length")
        # print(output)
        # print(tokenize.decode(output["input_ids"]))
        input_ids.append(output["input_ids"])
        attention_mask.append(output["attention_mask"])
        token_type_ids.append(output["token_type_ids"])
        labels.append(label)
        # print(input_ids)
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    token_type_ids = torch.tensor(token_type_ids)
    labels = torch.tensor(labels)
    dic = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "labels": labels}
    return dic
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

5.模型构建

from transformers import BertPreTrainedModel, BertModel, BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")


class BertQA(BertPreTrainedModel):
    def __init__(self, config, freeze=True):
        super(BertQA, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # 冻结bert参数,只fine-tuning后面层的参数
        if freeze:
            for p in self.bert.parameters():
                p.requires_grad = False
        self.qa_ouputs = nn.Linear(config.hidden_size, 2)
        self.loss_fn = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        logits = self.qa_ouputs(outputs[1])
        # 通过全连接网络,将特征转化为一个二维向量,可以看作标签0和1的得分情况
        predicted_labels = torch.softmax(logits, dim=-1)
        # 如果输入数据中含有标准答案,就计算loss值(即训练过程)
        if labels is not None:
            loss = self.loss_fn(predicted_labels, labels)
            return {"loss": loss, "predicted_labels": predicted_labels}
        # 否则返回预测值(测试过程)
        else:
            return {"predicted_labels": predicted_labels}


model = BertQA(config)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  • 1
  • 2
  • 3

6.训练模型

这里使用 transformers提供的trainer进行训练,此方法的通用性很高,可以在很多模型上使用,节约了书写训练过程的时间。


from transformers import Trainer, TrainingArguments
from datasets import Dataset
# 自定义评测指标
def compute_metrics(pred):
    # pred: ----> label_ids predictions
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = (labels==preds).sum()/len(labels)
    return {"acc":acc}

train_dataset = Dataset.from_dict(get_data(train_file))
eval_dataset = Dataset.from_dict(get_data(valid_file))
test_dataset = Dataset.from_dict(get_data(test_file))
# logging_steps  为展示 的 步  save_steps=10,
# evaluation_strategy="steps" 和 eval_steps=10 要同时使用才有效
args = TrainingArguments(output_dir="./result", gradient_accumulation_steps=10, learning_rate=1e-3,logging_dir="./logging",
                         num_train_epochs=2, per_device_train_batch_size=8,
                         logging_steps=100,eval_steps=100,evaluation_strategy="steps",seed=2022,save_steps=False,
                         per_device_eval_batch_size=8)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset,
                  compute_metrics=compute_metrics)
trainer.train()
# 这个用于测试传入测试集
trainer.evaluate(test_dataset)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
{'eval_loss': 0.3608052432537079,
 'eval_acc': 0.9524736415247365,
 'eval_runtime': 11.6196,
 'eval_samples_per_second': 530.57,
 'epoch': 2.0,
 'eval_mem_cpu_alloc_delta': 12288,
 'eval_mem_gpu_alloc_delta': 0,
 'eval_mem_cpu_peaked_delta': 0,
 'eval_mem_gpu_peaked_delta': 19019776}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/356582?site
推荐阅读
相关标签
  

闽ICP备14008679号