赞
踩
Datewhle29期__NLP之transformer :
学习资料地址:
https://datawhalechina.github.io/learn-nlp-with-transformers/#/
github地址:
https://github.com/datawhalechina/learn-nlp-with-transformers
datasets = load_dataset("squad_v2" if squad_v2 else "squad")
show_random_elements(datasets["train"], num_examples=2)
一般来说预训练模型输入有最大长度要求,所以我们通常将超长的输入进行截断。但是,如果我们将问答数据三元组<question, context, answer>中的超长context截断,那么我们可能丢掉答案(因为我们是从context中抽取出一个小片段作为答案)。为了解决这个问题,下面的代码找到一个超过长度的例子,然后向您演示如何进行处理。我们把超长的输入切片为多个较短的输入,每个输入都要满足模型最大长度输入要求。由于答案可能存在与切片的地方,因此我们需要允许相邻切片之间有交集,代码中通过doc_stride参数控制。
机器问答预训练模型通常将question和context拼接之后作为输入,然后让模型从context里寻找答案。
max_length = 384 # 输入feature的最大长度,question和context拼接之后
doc_stride = 128 # 2个切片之间的重合token数量
for循环遍历数据集,寻找一个超长样本→截断切片→input_ids还原为文本格式
tokenized_example = tokenizer(
example["question"],
example["context"],
max_length=max_length,
truncation="only_second",
return_overflowing_tokens=True,
return_offsets_mapping=True,
stride=doc_stride
)
# 打印切片前后位置下标的对应关系
print(tokenized_example["offset_mapping"][0][:100])
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)
[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]
answers = example["answers"] start_char = answers["answer_start"][0] end_char = start_char + len(answers["text"][0]) # 找到当前文本的Start token index. token_start_index = 0 while sequence_ids[token_start_index] != 1: token_start_index += 1 # 找到当前文本的End token idnex. token_end_index = len(tokenized_example["input_ids"][0]) - 1 while sequence_ids[token_end_index] != 1: token_end_index -= 1 # 检测答案是否在文本区间的外部,这种情况下意味着该样本的数据标注在CLS token位置。 offsets = tokenized_example["offset_mapping"][0] if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): # 将token_start_index和token_end_index移动到answer所在位置的两侧. # 注意:答案在最末尾的边界条件. while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: token_start_index += 1 start_position = token_start_index - 1 while offsets[token_end_index][1] >= end_char: token_end_index -= 1 end_position = token_end_index + 1 print("start_position: {}, end_position: {}".format(start_position, end_position)) else: print("The answer is not in this feature.")
start_position: 23, end_position: 26
验证: 使用答案所在位置下标,取到对应的token ID,然后转化为文本,然后和原始答案进行但对比。
最后, 对数据集datasets里面的所有样本进行预处理,处理的方式是使用map函数,将预处理函数prepare_train_features应用到(map)所有样本上。
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
args = TrainingArguments(
f"test-squad",
evaluation_strategy = "epoch",
learning_rate=2e-5, #学习率
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=3, # 训练的论次
weight_decay=0.01,
)
n_best_size = 20 import numpy as np start_logits = output.start_logits[0].cpu().numpy() end_logits = output.end_logits[0].cpu().numpy() # 收集最佳的start和end logits的位置: start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() valid_answers = [] for start_index in start_indexes: for end_index in end_indexes: if start_index <= end_index: # 如果start小雨end,那么合理的 valid_answers.append( { "score": start_logits[start_index] + end_logits[end_index], "text": "" # 后续需要根据token的下标将答案找出来 } )
实现以上, 添加以下两个信息到validation的features里面:
后将prepare_validation_features函数应用到每个验证集合的样本上。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。