当前位置:   article > 正文

Datawhale组队学习NLP_Bert抽取式问答学习笔记_distilbert-base-uncased 问答

distilbert-base-uncased 问答

本文为学习Datawhale 2021.8组队学习NLP入门之Transformer笔记
原学习文档地址:https://github.com/datawhalechina/learn-nlp-with-transformers

任务:抽取式问答
数据集:squad
三个key:“context", "question"和“answers”

# 展示训练集的第一个句子
datasets["train"][0]
{
   'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {
   'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

answers保存答案的开始位置和整个答案内容

1 数据读入

依旧是colab下载后本地读入

from datasets import load_from_disk

datasets = load_from_disk("E:/jupyter_notebook/0_learn-nlp-with-transformers-main/docs/篇章4-使用Transformers解决NLP任务/datasets/squad")
  • 1
  • 2
  • 3

2 数据预处理

定义

squad_v2 = False
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

from transformers import AutoTokenizer
import transformers

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

抽取式问答的数据预处理有几个点

一是context可能会很长,如何处理长度超过max_length的文本?
二是tokenizer以后,start_label和end_label要重新定位
三是切割文本以后,label标记可能又会出现问题

def prepare_train_features(examples):
    # 既要对examples进行truncation(截断)和padding(补全)还要还要保留所有信息,所以要用的切片的方法。
    # 每一个一个超长文本example会被切片成多个输入,相邻两个输入之间会有交集。
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",  # 如果context是拼接在question后面的,对应着第2个文本,所以使用only_second控制
        max_length=max_length,
        stride=doc_stride,  # tokenizer使用doc_stride控制切片之间的重合长度
        return_overflowing_tokens=True,
        return_offsets_mapping=True,  # 可以得到token对应原context中的位置
        padding="max_length",
    )
    # !!这里用的是pop方法
    # 我们使用overflow_to_sample_mapping参数来映射切片片ID到原始ID。
    # 比如有2个expamples被切成4片,那么对应是[0, 0, 1, 1],前两片对应原来的第一个example。
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # offset_mapping也对应4片
    # offset_mapping参数帮助我们映射到原始输入,由于答案标注在原始输入上,所以有助于我们找到答案的起始和结束位置。
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # 重新标注数据
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):  # i就是第几个句子,offsets是一个存储这个句子每个token在原context中对应位置的列表
        # 对每一片进行处理
        # 将无答案的样本标注到CLS上
        input_ids = tokenized_examples["input_ids"][i]  # 得到这个句子的输入token
        cls_index = input_ids.index(tokenizer.cls_token_id)  # 找到CLS也就是token为101的位置 = 0

        # 区分question和context
        sequence_ids = tokenized_examples.sequence_ids(i)  # None,0,1分别标注为特殊符号,第一个句子和第二个句子

        # 拿到原始的example 下标.
        sample_index = sample_mapping[i]  # 第i个切片对应的原context标号
        answers = examples["answers"][sample_index]  # 得到该切片对应的原context的answer
        # 如果没有答案,则使用CLS所在的位置为答案.
        if len(answers["answer_start"]) == 0:  # 感觉这里就需要看具体的数据集标注了,这里认为没有答案的数据集answer_start什么都没存
            tokenized_examples["start_positions"].append(cls_index) # 没有答案就标注头尾都在CLS
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # 答案的character级别Start/end位置.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # 找到token级别的index start.
            token_start_index = 0
            # sequence_ids就是存0,1,None的一个区分句子的列表
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # 找到token级别的index end.
            token_end_index = len(input_ids) - 1
            # 这里的输入经过tokenizer以后,长度全变成最大长度384了,input_ids没有做填充,后面都是0,0对应的label也是None
            # 所以0肯定不会代表token
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # 检测答案是否超出文本长度,超出的话也适用CLS index作为标注.
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/347523
推荐阅读
相关标签
  

闽ICP备14008679号