This post contains my notes from the Datawhale August 2021 group study "Introduction to NLP with Transformers".
Original course materials: https://github.com/datawhalechina/learn-nlp-with-transformers
Task: extractive question answering
Dataset: squad
Each example has three keys: "context", "question" and "answers".
# Show the first example in the training set
datasets["train"][0]
{
'id': '5733be284776f41900661182',
'title': 'University_of_Notre_Dame',
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
'answers': {
'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}
The answers field stores both the start position of the answer (as a character offset into the context) and the full answer text.
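A quick sanity check (a minimal sketch; the variable names here are my own): answer_start is a character offset into the context, so slicing the context at that offset should give back the answer text.
example = datasets["train"][0]
start = example["answers"]["answer_start"][0]
answer = example["answers"]["text"][0]
# The character span [start, start + len(answer)) of the context is exactly the answer
assert example["context"][start:start + len(answer)] == answer  # 'Saint Bernadette Soubirous'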
As in the previous tasks, the dataset was downloaded in Colab and is loaded here from local disk.
from datasets import load_from_disk
datasets = load_from_disk("E:/jupyter_notebook/0_learn-nlp-with-transformers-main/docs/篇章4-使用Transformers解决NLP任务/datasets/squad")
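If the machine can reach the Hugging Face hub directly, the same dataset can also be fetched on the fly with load_dataset instead of load_from_disk (an alternative sketch, not part of the original notes):
from datasets import load_dataset
# Downloads and caches SQuAD from the hub
datasets = load_dataset("squad")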
Define a few constants:
squad_v2 = False
model_checkpoint = "distilbert-base-uncased"
batch_size = 16
from transformers import AutoTokenizer
import transformers
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
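The assert matters because only the fast (Rust-backed) tokenizers can return character offsets for each token, which the preprocessing below depends on. A minimal sketch of what that looks like (the example strings are adapted from the sample shown above):
enc = tokenizer(
    "To whom did the Virgin Mary appear?",
    "The Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858.",
    return_offsets_mapping=True,
)
# Each token gets a (start_char, end_char) span in its source string; special tokens get (0, 0)
print(enc["offset_mapping"][:6])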
Data preprocessing for extractive QA has a few tricky points:
First, the context can be very long. How do we handle text whose length exceeds max_length?
Second, after tokenization the start/end labels have to be re-located, since the annotations are character positions in the original text while the model needs token indices.
Third, once the text is split into chunks the labels can break again, because the answer may not fall inside a given chunk.
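The preprocessing function below also relies on three values defined outside of it. Based on the 384 mentioned in its comments and the defaults of the original tutorial, they look roughly like this (a sketch of assumed values):
max_length = 384   # maximum length of one chunk (question + context), in tokens
doc_stride = 128   # overlap between two consecutive chunks of a long context
# Whether the context sits on the right of the (question, context) pair; depends on the tokenizer
pad_on_right = tokenizer.padding_side == "right"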
def prepare_train_features(examples):
    # We need to truncate/pad the examples while still keeping all of the information,
    # so long examples are split into chunks: one over-long example becomes several
    # inputs, and adjacent chunks overlap.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        # When the context is appended after the question it is the second text,
        # so "only_second" truncates only the context.
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,  # doc_stride controls the overlap between consecutive chunks
        return_overflowing_tokens=True,
        return_offsets_mapping=True,  # gives each token's character span in the original context
        padding="max_length",
    )

    # Note that pop() is used here, so these keys are removed from tokenized_examples.
    # overflow_to_sample_mapping maps each chunk back to its original example:
    # if 2 examples are split into 4 chunks it is [0, 0, 1, 1], i.e. the first
    # two chunks come from the first example.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # offset_mapping also has one entry per chunk.
    # It maps tokens back to character positions in the original input; since the answers
    # are annotated on the original text, it lets us locate the answer's start and end.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Re-label the data with token-level start/end positions.
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # i is the chunk index; offsets stores each token's character span in the original context.
        # Process each chunk; chunks without an answer are labeled on the [CLS] token.
        input_ids = tokenized_examples["input_ids"][i]        # token ids of this chunk
        cls_index = input_ids.index(tokenizer.cls_token_id)   # position of [CLS] (token id 101), i.e. 0

        # Distinguish question from context:
        # None / 0 / 1 mark special tokens, the first text and the second text.
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Index of the original example this chunk comes from.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]           # answers of the original example
        # If there is no answer, use the [CLS] position as the answer.
        if len(answers["answer_start"]) == 0:
            # This depends on the dataset's annotation scheme; here an unanswerable
            # example is assumed to have an empty answer_start list.
            tokenized_examples["start_positions"].append(cls_index)  # start and end both point to [CLS]
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Character-level start/end positions of the answer.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Find the token-level start index.
            # sequence_ids is the 0/1/None list that separates the two texts.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # Find the token-level end index.
            # After tokenization every chunk is padded to max_length (384); the padding
            # positions have sequence_id None, so they can never be mistaken for context tokens.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # If the answer is outside this chunk, also label it with the [CLS] index.
            if not (offsets[token_start_index][0] <= start_char
                    and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move token_start_index and token_end_index to the two ends of
                # the answer (each loop overshoots by one token, hence the -1 / +1).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples
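With prepare_train_features defined, the usual next step is to apply it to every split with datasets.map; a minimal usage sketch (batched mapping plus remove_columns lets one example expand into several chunks):
tokenized_datasets = datasets.map(
    prepare_train_features,
    batched=True,                                   # one example may produce several chunks
    remove_columns=datasets["train"].column_names,  # drop the raw text columns
)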