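As for hallucination in summary generation, i.e. intrinsic hallucination (generated content that contradicts the source) and extrinsic hallucination (generated content that cannot be verified from the source), there is a good survey: Survey of Hallucination in Natural Language Generation: https://arxiv.org/pdf/2202.03629.pdf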


1.1 Word segmentation or not?


1.2 Is the vocabulary character-based or word-based? Unify the format


For example, a model whose vocabulary is character-based (Chinese BART being the typical case) generates summaries in transformers in this form: ['我 是 生 成 的 摘 要']

whereas a model whose vocabulary is word-based (typically Chinese T5 Pegasus) generates summaries in this form: ['我 是 生成摘要']


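So before computing any metric, first remove all spaces so that the outputs of both kinds of model become plain strings:

```python
# Remove the spaces inserted when decoding, so every output is a plain string
decoded_preds = [pred.replace(" ", "") for pred in decoded_preds]
decoded_labels = [label.replace(" ", "") for label in decoded_labels]
```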


For example, take 1234举起手啊: split into characters it becomes 1 2 3 4 举 起 手 啊, while tokenizing with a word-level vocabulary gives 1234 举 起 手 啊.
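A quick check of what character-level splitting does. It uses Python's `str.join` on a string, which iterates over its characters. The check shows that a digit run such as 1234 gets split too, and that a word-level output like the T5 Pegasus example can be re-spaced to the character level:

```python
text = "1234举起手啊"
# Character level: every character, including each digit, becomes its own token
char_level = " ".join(text)
print(char_level)  # 1 2 3 4 举 起 手 啊

# Same trick on a word-level model's output: strip spaces first, then re-space per character
word_level_output = "我 是 生成摘要"
respaced = " ".join(word_level_output.replace(" ", ""))
print(respaced)  # 我 是 生 成 摘 要
```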

1.3 Computing the metrics



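Word level: strip the spaces, then re-segment with jieba so that words are separated by spaces:

```python
import jieba

decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]
```

For example:

Prediction: ['我 是 生成摘要']

Reference: ['我 是 参考 的 摘要']

Both now have the same format, so the scores can be computed directly: result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)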


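Character level: strip the spaces, then put a space between every character:

```python
decoded_preds = [" ".join(pred.replace(" ", "")) for pred in decoded_preds]
decoded_labels = [" ".join(label.replace(" ", "")) for label in decoded_labels]

result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
```

which gives:

Prediction: ['我 是 生 成 摘 要']

Reference: ['我 是 参 考 的 摘 要']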

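The spaces are there because the rouge library expects tokens separated by spaces. If you use the lawrouge library instead, no spaces are needed; just call result = rouge.get_scores(decoded_preds, decoded_labels, avg=True) directly.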
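To see why the granularity matters, here is a simplified ROUGE-1 F1 applied to the word-level and character-level example pairs from this section. It computes clipped unigram overlap over space-separated tokens and is written by hand rather than taken from the rouge package, whose exact implementation may differ in small details. The function name `rouge1_f` is illustrative only:

```python
from collections import Counter


def rouge1_f(hyp: str, ref: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap over space-separated tokens."""
    hyp_counts = Counter(hyp.split())
    ref_counts = Counter(ref.split())
    overlap = sum((hyp_counts & ref_counts).values())  # min count per shared token
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


pred, ref = "我是生成摘要", "我是参考的摘要"
# Word level, using the segmentation shown above
word = rouge1_f("我 是 生成摘要", "我 是 参考 的 摘要")
# Character level: one token per character
char = rouge1_f(" ".join(pred), " ".join(ref))
print(f"word-level ROUGE-1 F: {word:.3f}")  # 0.500
print(f"char-level ROUGE-1 F: {char:.3f}")  # 0.615
```

The same prediction scores noticeably higher at character level, because 生成摘要 only matches 摘要 once it is broken into characters. Scores computed at different granularities are therefore not comparable.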

1.4 Differences in how the scores are computed


```python
import jieba  # only needed for the word-level variant below
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge


# Assumes a global `tokenizer` (the summarization model's tokenizer)
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Predictions from different batches may be padded with -100 when concatenated
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Character level
    decoded_preds = [" ".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = [" ".join(label.replace(" ", "")) for label in decoded_labels]
    # Word level (jieba segmentation)
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]

    rouge = Rouge()
    # Length of each generated sequence (non-pad tokens), reported as gen_len
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]

    total = 0
    rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
    # Score each sample separately, then average
    for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
        total += 1
        scores = rouge.get_scores(hyps=decoded_pred, refs=decoded_label)
        rouge_1 += scores[0]['rouge-1']['f']
        rouge_2 += scores[0]['rouge-2']['f']
        rouge_l += scores[0]['rouge-l']['f']
        bleu += sentence_bleu(
            references=[decoded_label.split(' ')],
            hypothesis=decoded_pred.split(' '),
            smoothing_function=SmoothingFunction().method1,
        )
    bleu /= total
    rouge_1 /= total
    rouge_2 /= total
    rouge_l /= total
    result = {'rouge-1': rouge_1, 'rouge-2': rouge_2, 'rouge-l': rouge_l}
    print(result)

    # Check whether averaging per-sample scores matches avg=True
    result2 = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    print(result2)
    print(bleu)

    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(prediction_lens)
    result["bleu"] = bleu * 100
    return result
```


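The pipeline is fairly complex. The QA-based approach, for example, requires training a question generation (QG) model and an answer generation (question answering, QA) model separately, and how well these models are trained directly affects the result. Below is a brief introduction to training QA and QG: QA is based on BERT and QG on BART. English SQuAD-1.1 is used here; the method for Chinese is the same: use CMRC2018 data in SQuAD format and swap in Chinese models. Different papers implement this differently; I only describe the very simplest approach.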
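How the two models are combined into a score is not spelled out here. As a hedged sketch, assume the common QAGS-style scheme: generate questions from the summary, answer each one against both the source and the summary, then average the token-level F1 between the two answers. In the block below, `generate_questions` and `answer` are hypothetical callables standing in for the trained QG and QA models. `token_f1` re-implements SQuAD-style F1 (the same idea as the `compute_f1` imported later) so that the block runs on its own:

```python
import re
import string
from collections import Counter
from typing import Callable, List


def normalize_answer(s: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles, collapse spaces."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def token_f1(gold: str, pred: str) -> float:
    """Token-overlap F1 between two answer strings."""
    gold_toks = normalize_answer(gold).split()
    pred_toks = normalize_answer(pred).split()
    if not gold_toks or not pred_toks:
        # If either answer is empty, F1 is 1 only when both are empty
        return float(gold_toks == pred_toks)
    num_same = sum((Counter(gold_toks) & Counter(pred_toks)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def qa_consistency(
    source: str,
    summary: str,
    generate_questions: Callable[[str], List[str]],  # hypothetical wrapper around the QG model
    answer: Callable[[str, str], str],  # hypothetical wrapper around the QA model: (question, context) -> answer
) -> float:
    """Average F1 between answers read from the source and from the summary."""
    questions = generate_questions(summary)
    if not questions:
        return 0.0
    scores = [token_f1(answer(q, source), answer(q, summary)) for q in questions]
    return sum(scores) / len(scores)


# Toy stand-ins for the trained models, for illustration only
def toy_qg(text: str) -> List[str]:
    return ["Who won Super Bowl 50?"]


def toy_qa(question: str, context: str) -> str:
    return "Denver Broncos" if "Broncos" in context else "Carolina Panthers"


source = "The Denver Broncos won Super Bowl 50."
print(qa_consistency(source, "The Broncos won.", toy_qg, toy_qa))   # 1.0
print(qa_consistency(source, "The Panthers won.", toy_qg, toy_qa))  # 0.0
```

A summary that agrees with the source gets the same answers from both and scores high; a summary that swaps in a different team gets a different answer and scores 0.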


2.2 Data loading



The data is loaded through a HuggingFace datasets loading script, squad.py (2.3 loads it with load_dataset('./squad.py')), adapted from the official SQuAD script so that it reads the local SQuAD-1.1 files:

```python
import json

import datasets
# datasets.tasks is deprecated and removed in newer datasets releases;
# there, drop this import and the task_templates argument below
from datasets.tasks import QuestionAnsweringExtractive

logger = datasets.logging.get_logger(__name__)

_CITATION = """\
@article{2016arXiv160605250R,
       author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
                 Konstantin and {Liang}, Percy},
        title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
      journal = {arXiv e-prints},
         year = 2016,
          eid = {arXiv:1606.05250},
        pages = {arXiv:1606.05250},
archivePrefix = {arXiv},
       eprint = {1606.05250},
}
"""

_DESCRIPTION = """\
Stanford Question Answering Dataset (SQuAD) is a reading comprehension \
dataset, consisting of questions posed by crowdworkers on a set of Wikipedia \
articles, where the answer to every question is a segment of text, or span, \
from the corresponding reading passage, or the question might be unanswerable.
"""

# Local directory holding train-v1.1.json and dev-v1.1.json
_URL = r"E:\Project\NLP\dataset\SQuAD-1.1 datasets/"
_URLS = {
    "train": _URL + "train-v1.1.json",
    "dev": _URL + "dev-v1.1.json",
}


class SquadConfig(datasets.BuilderConfig):
    """BuilderConfig for SQUAD."""

    def __init__(self, **kwargs):
        """BuilderConfig for SQUAD.

        Args:
          **kwargs: keyword arguments forwarded to super.
        """
        super(SquadConfig, self).__init__(**kwargs)


class Squad(datasets.GeneratorBasedBuilder):
    """SQUAD: The Stanford Question Answering Dataset. Version 1.1."""

    BUILDER_CONFIGS = [
        SquadConfig(
            name="plain_text",
            version=datasets.Version("1.0.0", ""),
            description="Plain text",
        ),
    ]

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "title": datasets.Value("string"),
                    "context": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "answers": datasets.features.Sequence(
                        {
                            "text": datasets.Value("string"),
                            "answer_start": datasets.Value("int32"),
                        }
                    ),
                }
            ),
            # No default supervised_keys (as we have to pass both question
            # and context as input).
            supervised_keys=None,
            homepage="https://rajpurkar.github.io/SQuAD-explorer/",
            citation=_CITATION,
            task_templates=[
                QuestionAnsweringExtractive(
                    question_column="question", context_column="context", answers_column="answers"
                )
            ],
        )

    def _split_generators(self, dl_manager):
        # Local paths are returned unchanged by the download manager
        downloaded_files = dl_manager.download_and_extract(_URLS)
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
        ]

    def _generate_examples(self, filepath):
        """This function returns the examples in the raw (text) form."""
        logger.info("generating examples from = %s", filepath)
        key = 0
        with open(filepath, encoding="utf-8") as f:
            squad = json.load(f)
            for article in squad["data"]:
                title = article.get("title", "")
                for paragraph in article["paragraphs"]:
                    context = paragraph["context"]  # do not strip leading blank spaces GH-2585
                    for qa in paragraph["qas"]:
                        answer_starts = [answer["answer_start"] for answer in qa["answers"]]
                        answers = [answer["text"] for answer in qa["answers"]]
                        # Features currently used are "context", "question", and "answers".
                        # Others are extracted here for the ease of future expansions.
                        yield key, {
                            "title": title,
                            "context": context,
                            "question": qa["question"],
                            "id": qa["id"],
                            "answers": {
                                "answer_start": answer_starts,
                                "text": answers,
                            },
                        }
                        key += 1
```
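To see what `_generate_examples` produces without running datasets, the same flattening logic (data → paragraphs → qas) can be pulled into a plain generator and run on a toy record in the SQuAD-1.1 layout. The function name `flatten_squad` and the toy values are made up for illustration. Note that one question can carry several reference answers, as in the dev set, which is why `preprocess_function` in 2.3 takes only `[0]`:

```python
def flatten_squad(squad: dict):
    """Yield (key, example) pairs the same way _generate_examples does."""
    key = 0
    for article in squad["data"]:
        title = article.get("title", "")
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                yield key, {
                    "title": title,
                    "context": context,
                    "question": qa["question"],
                    "id": qa["id"],
                    "answers": {
                        "answer_start": [a["answer_start"] for a in qa["answers"]],
                        "text": [a["text"] for a in qa["answers"]],
                    },
                }
                key += 1


toy = {"data": [{"title": "Super_Bowl_50", "paragraphs": [{
    "context": "The Broncos won Super Bowl 50.",
    "qas": [{"id": "q1", "question": "Who won Super Bowl 50?",
             "answers": [{"text": "Broncos", "answer_start": 4},
                         {"text": "The Broncos", "answer_start": 0}]}],
}]}]}

examples = list(flatten_squad(toy))
print(examples[0][1]["answers"])
```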

2.3 QA training


```python
# coding=utf-8
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertForQuestionAnswering,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

# Load SQuAD through the local loading script from 2.2
squad = load_dataset('./squad.py')
print(squad)

tokenizer = AutoTokenizer.from_pretrained(r"E:\Project\NLP\bert-base-uncased")


def preprocess_function(examples):
    '''
    Written for the training set, where each question has exactly one reference answer.
    For the validation set only the first reference answer is used.
    '''
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=512,
        truncation="only_second",  # truncate only the context, never the question
        return_offsets_mapping=True,
        padding="max_length",
    )
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []
    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context (tokens with sequence id 1)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the (possibly truncated) context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)
data_collator = default_data_collator
model = BertForQuestionAnswering.from_pretrained(r"E:\Project\NLP\long-document\bert-base-uncased")

training_args = TrainingArguments(
    # fp16=True,
    output_dir="./QA_results",
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_dir="logs",
    logging_strategy="steps",
    save_total_limit=3,
    logging_steps=1,
    num_train_epochs=4,
    weight_decay=0.01,
    gradient_accumulation_steps=8,
)


def compute_metrics(eval_pred):
    predictions, _ = eval_pred
    # predictions = (start_logits, end_logits); take the most likely token positions
    answer_start = np.argmax(predictions[0], axis=1)
    answer_end = np.argmax(predictions[1], axis=1)
    data = tokenized_squad["validation"]
    gold_answers = []
    pred_answers = []
    # Iterate over every validation example
    for idx, example in enumerate(data):
        input_ids = example["input_ids"]
        # Gold positions are token indices after tokenization, not character offsets
        label_start = example["start_positions"]
        label_end = example["end_positions"]
        # If the predicted end comes before the start, collapse the span to one token
        if answer_end[idx] < answer_start[idx]:
            answer_end[idx] = answer_start[idx]
        # Spans are compared as space-separated token-ID strings (F1/EM over token IDs, not text)
        gold_answer = " ".join(str(t) for t in input_ids[label_start:label_end + 1])
        pred_answer = " ".join(str(t) for t in input_ids[answer_start[idx]:answer_end[idx] + 1])
        gold_answers.append(gold_answer)
        pred_answers.append(pred_answer)
    # Compute F1 and exact match
    f1_score = 0
    exact_score = 0
    for gold_answer, pred_answer in zip(gold_answers, pred_answers):
        f1_score += compute_f1(gold_answer, pred_answer)
        exact_score += compute_exact(gold_answer, pred_answer)
    f1_score = f1_score / len(gold_answers) * 100
    exact_score = exact_score / len(gold_answers) * 100
    return {'f1_score': f1_score, 'exact_score': exact_score}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
```
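The core of `preprocess_function` is turning the answer's character span into token positions via `offset_mapping`. Below, the same two loops are pulled out into a hypothetical helper, with the "not fully inside the context" check written as the code comment describes. The helper runs on a hand-written offset list, so no tokenizer is needed. The offsets mimic a `[CLS] question [SEP] context [SEP]` layout for the question "Who" and the context "Super Bowl 50 was":

```python
from typing import List, Tuple


def char_span_to_token_span(
    offsets: List[Tuple[int, int]],
    context_start: int,
    context_end: int,
    start_char: int,
    end_char: int,
) -> Tuple[int, int]:
    """Map the character span [start_char, end_char) to inclusive token indices.

    Returns (0, 0) when the answer is not fully inside the (possibly truncated) context.
    """
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0
    # Walk forward to the last token that starts at or before start_char
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Walk backward to the first token that ends at or after end_char
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# [CLS] Who [SEP] Super Bowl 50 was [SEP]; context tokens are indices 3..6
offsets = [(0, 0), (0, 3), (0, 0), (0, 5), (6, 10), (11, 13), (14, 17), (0, 0)]
print(char_span_to_token_span(offsets, 3, 6, 6, 13))   # "Bowl 50" -> (4, 5)
print(char_span_to_token_span(offsets, 3, 6, 14, 21))  # runs past the truncated context -> (0, 0)
```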

2.4 QG training


```python
import json

import numpy as np
from datasets import Dataset
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from transformers.data.metrics.squad_metrics import compute_f1

max_input_length = 512
max_target_length = 128
train_path = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\train-v1.1.json'
dev_path = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\dev-v1.1.json'
output_dir = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\QG_results'
tokenizer_path = r'E:\Project\NLP\bart-base-english'
model_path = r'E:\Project\NLP\bart-base-english'


def data_preprocess(path):
    """Turn a SQuAD JSON file into (answer-highlighted context, question) pairs."""
    with open(path, 'r', encoding='utf-8') as f:
        datas = json.load(f)
    # Flatten data -> paragraphs -> qas
    new_data = []
    for data in datas["data"]:
        for d in data['paragraphs']:
            context = d['context']
            for qa in d['qas']:
                new_data.append({
                    'context': context,
                    'answers': qa['answers'],
                    'question': qa['question'],
                })
    contexts = []
    labels = []
    for data in new_data:
        # Wrap the (first) answer span in <hl> markers so the model knows what to ask about
        answer_text = data['answers'][0]['text']
        answer_len = len(answer_text)
        answer_start = data['answers'][0]['answer_start']
        hl_context = (data['context'][:answer_start] + '<hl>' + answer_text + '<hl>'
                      + data['context'][answer_start + answer_len:])
        contexts.append(hl_context)
        labels.append(data['question'])  # the target is the question
    return contexts, labels


train_contexts, train_labels = data_preprocess(train_path)
dev_contexts, dev_labels = data_preprocess(dev_path)
# Slice the lists (e.g. [0:50]) for a quick test run
train = {"contexts": train_contexts, "labels": train_labels}
dev = {"contexts": dev_contexts, "labels": dev_labels}
train_dataset = Dataset.from_dict(train).shuffle(seed=42)
dev_dataset = Dataset.from_dict(dev).shuffle(seed=42)

tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
# Register <hl> as a special token so it is not split into sub-words
tokenizer.add_special_tokens({'additional_special_tokens': ['<hl>']})


def preprocess_function(examples):
    model_inputs = tokenizer(examples["contexts"], max_length=max_input_length, truncation=True)
    # Tokenize the targets; text_target replaces the older as_target_tokenizer() context manager
    labels = tokenizer(text_target=examples["labels"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_dev_dataset = dev_dataset.map(preprocess_function, batched=True, remove_columns=dev_dataset.column_names)

args = Seq2SeqTrainingArguments(
    fp16=True,
    output_dir=output_dir,
    num_train_epochs=5,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=1,  # demo
    per_device_eval_batch_size=1,
    learning_rate=1e-04,
    warmup_steps=100,
    weight_decay=0.01,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=1,
    save_total_limit=3,
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    generation_max_length=max_target_length,
    generation_num_beams=4,
)

model = BartForConditionalGeneration.from_pretrained(model_path)
model.resize_token_embeddings(len(tokenizer))  # make room for the new <hl> token
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Generated sequences from different batches may be padded with -100 when concatenated
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Token-overlap F1 between generated and reference questions
    f1_score = 0
    for label, pred in zip(decoded_labels, decoded_preds):
        f1_score += compute_f1(label, pred)
    f1_score = f1_score / len(decoded_preds) * 100
    return {'f1_score': f1_score}


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_dev_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
train_result = trainer.train()
```
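`data_preprocess` marks the answer span with `<hl>` tokens so that the QG model knows which answer to ask about. Below is a standalone version of that step with an added sanity check that `answer_start` really points at the answer text. The function name `highlight_answer` and the check are illustrative additions, not part of the original script:

```python
def highlight_answer(context: str, answer_start: int, answer_text: str, hl: str = "<hl>") -> str:
    """Wrap the answer span of `context` in highlight markers."""
    end = answer_start + len(answer_text)
    # answer_start is a character offset; make sure it matches the answer text
    if context[answer_start:end] != answer_text:
        raise ValueError(f"answer_start {answer_start} does not point at {answer_text!r}")
    return context[:answer_start] + hl + answer_text + hl + context[end:]


print(highlight_answer("The Broncos won Super Bowl 50.", 4, "Broncos"))
# The <hl>Broncos<hl> won Super Bowl 50.
```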

2.5 Usage


3 BERTScore


```python
from bert_score import score

# data
cands = ['天天干家务烦死了', '难受死了啊']
refs = ['这也完全不相干啊', '真的难受死了啊']
# Returns per-sentence precision, recall and F1 tensors
P, R, F1 = score(cands, refs, lang="zh", verbose=True)
print(f"System level F1 score: {F1.mean():.3f}")
```
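For intuition about what `score()` returns: BERTScore embeds every token with a pretrained model and matches tokens greedily by cosine similarity. Recall averages, over reference tokens, the best match found in the candidate; precision does the reverse; F1 is their harmonic mean. The sketch below reproduces only that matching step on toy embeddings. It leaves out the optional idf weighting and baseline rescaling that the bert_score package supports, and the toy vectors and the name `greedy_bertscore` are made up for illustration:

```python
import numpy as np


def greedy_bertscore(cand_emb: np.ndarray, ref_emb: np.ndarray):
    """P, R, F1 from token embeddings via greedy cosine matching (no idf, no rescaling)."""
    cand = cand_emb / np.linalg.norm(cand_emb, axis=1, keepdims=True)
    ref = ref_emb / np.linalg.norm(ref_emb, axis=1, keepdims=True)
    sim = cand @ ref.T                   # (num_cand_tokens, num_ref_tokens) cosine similarities
    precision = sim.max(axis=1).mean()   # each candidate token -> its best reference token
    recall = sim.max(axis=0).mean()      # each reference token -> its best candidate token
    f1 = 2 * precision * recall / (precision + recall)
    return float(precision), float(recall), float(f1)


cand_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
ref_emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
p, r, f = greedy_bertscore(cand_emb, ref_emb)
print(f"P={p:.3f} R={r:.3f} F1={f:.3f}")
```

Every candidate token has an exact match, so precision is 1. The third reference token only partially matches anything in the candidate, which pulls recall below 1.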

