
Computing Metrics for Chinese Automatic Text Summarization: ROUGE / BLEU / BERTScore / QA Implementations


This part describes how to compute metrics between a generated summary and its reference summary. The metrics fall into two families. The first is n-gram based, e.g. Rouge-1, Rouge-2, Rouge-L and BLEU; these mainly measure surface-level overlap and fluency and cannot measure how factual or faithful the generated summary is. The second relies on auxiliary models such as entailment or QA, e.g. FEQA and QuestEval, which capture faithfulness much better. Finally, there is the simpler BertScore. All code below is written for the transformers library; see the earlier BART articles for where it plugs in.

As for the hallucination problem in summary generation (intrinsic hallucinations as well as extrinsic ones, i.e. content fabricated from outside the source), there is a very good survey: Survey of Hallucination in Natural Language Generation: https://arxiv.org/pdf/2202.03629.pdf

1 Computing ROUGE and BLEU

1.1 To segment words or not?

Before computing the metrics you need one consistent standard. Unlike English, Chinese metrics are computed differently by different people: some segment the text into words and some do not, some use jieba segmentation and some use hanlp.

1.2 Is the vocabulary character-based or word-based? Unify the standard

Different models use different vocabularies, so generated summaries must be normalized to a single standard!

For example, a model with a character-based vocabulary (typically Chinese BART) generates summaries in transformers in this form: ['我 是 生 成 的 摘 要']

whereas a model with a word-based vocabulary (typically Chinese T5 Pegasus) generates summaries in this form: ['我 是 生成摘要']

In this situation you have to unify the standard: first strip the spaces from the summaries so that everything becomes the form ['我是生成的摘要'], i.e.:

decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]
decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]

This creates one problem, though: for Chinese summaries that contain many digits or English words, character-level metrics will be inflated. If generation is character-based anyway, I do not recommend removing the spaces; the space-separated raw output that transformers produces is already fine.

For example: 1234举起手啊 ---> split by character it becomes 1 2 3 4 举 起 手 啊; following the (word) vocabulary it becomes 1234 举 起 手 啊
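If you want character-level scoring but do not want runs of digits or English letters split apart, a small helper like the one below can be applied first. This is my own regex sketch for illustration, not part of transformers or the rouge package:

import re

def char_level_tokens(text):
    """Split Chinese text into single characters, but keep runs of
    digits/Latin letters as one token: '1234举起手啊' -> '1234 举 起 手 啊'."""
    text = text.replace(" ", "")
    return " ".join(re.findall(r"[0-9A-Za-z]+|.", text))

print(char_level_tokens("1234举起手啊"))    # 1234 举 起 手 啊
print(char_level_tokens("我是生成的摘要"))   # 我 是 生 成 的 摘 要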

1.3 Computing the metrics

There are then two ways to compute the metrics: after word segmentation, or without segmentation (character level).

(1) Compute after word segmentation, here with jieba:

decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]

The results then look like:

['我 是 生成摘要']

['我 是 参考 的 摘要']

Both now have the same format, so you can compute directly: result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)

(2) Without segmentation, compute directly at the character level:

decoded_preds = [" ".join(pred.replace(" ", "")) for pred in decoded_preds]
decoded_labels = [" ".join(label.replace(" ", "")) for label in decoded_labels]

result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)

The results then look like:

['我 是 生 成 摘 要']

['我 是 参 考 的 摘 要']

The spaces are there because the rouge library needs space-separated tokens; if you use the lawrouge library they are not needed, just call result = rouge.get_scores(decoded_preds, decoded_labels, avg=True) directly.
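For reference, a minimal runnable example with the rouge package (pip install rouge), using the character-level strings above:

from rouge import Rouge

rouge = Rouge()
decoded_preds = ["我 是 生 成 摘 要"]       # generated summary, character level
decoded_labels = ["我 是 参 考 的 摘 要"]   # reference summary, character level
result = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
print(result["rouge-1"]["f"], result["rouge-2"]["f"], result["rouge-l"]["f"])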

1.4 Difference between the two

In general the two differ a lot: word-level metrics are usually much lower than character-level ones, by roughly 5-10 points. The full compute_metrics function is below, with the word-level lines left commented out:

import numpy as np
import jieba  # only needed for the word-level (segmented) variant
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

def compute_metrics(eval_pred):
    # `tokenizer` comes from the surrounding fine-tuning script
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Character level: strip spaces, then separate every character with a space
    decoded_preds = [" ".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = [" ".join(label.replace(" ", "")) for label in decoded_labels]
    # Word level (jieba segmentation) -- uncomment to use instead
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]
    rouge = Rouge()
    # Reference lengths (pad tokens excluded)
    labels_lens = [np.count_nonzero(label != tokenizer.pad_token_id) for label in labels]
    total = 0
    rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
    for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
        total += 1
        scores = rouge.get_scores(hyps=decoded_pred, refs=decoded_label)
        rouge_1 += scores[0]['rouge-1']['f']
        rouge_2 += scores[0]['rouge-2']['f']
        rouge_l += scores[0]['rouge-l']['f']
        bleu += sentence_bleu(
            references=[decoded_label.split(' ')],
            hypothesis=decoded_pred.split(' '),
            smoothing_function=SmoothingFunction().method1
        )
    bleu /= len(decoded_labels)
    rouge_1 /= total
    rouge_2 /= total
    rouge_l /= total
    result = {'rouge-1': rouge_1, 'rouge-2': rouge_2, 'rouge-l': rouge_l}
    print(result)
    # Check that per-sample averaging matches rouge's built-in avg=True
    result2 = rouge.get_scores(decoded_preds, decoded_labels, avg=True)
    print(result2)
    print(bleu)
    # result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}
    result = {key: value * 100 for key, value in result.items()}
    result["gen_len"] = np.mean(labels_lens)
    result["bleu"] = bleu * 100
    return result
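This function is then passed to the trainer exactly as in the earlier BART fine-tuning articles. A minimal reminder (the model and dataset variable names below are illustrative, they are not defined in this post):

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,                           # your seq2seq summarization model
    args=training_args,                    # Seq2SeqTrainingArguments with predict_with_generate=True
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_dev_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,       # the function above
)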

2 QA & QG

The pipeline here is more involved. A QA-based metric requires training a Question Generation (QG) model and an Answer/QA model separately, and how well they are trained directly affects the metric. Below is a brief walkthrough of QA and QG training: QA is based on BERT and QG on BART. English SQuAD-1.1 is used here; the Chinese procedure is identical, just use CMRC2018 data in SQuAD format and swap in Chinese models. Different papers implement this differently; I only describe the simplest possible approach.

2.1 Directory structure

The project simply contains squad.py (and cmrc2018.py), QA_finetune.py and QG_finetune.py next to the local dataset folder.

2.2 Data loading

squad.py -- my cmrc2018.py is identical to squad.py, only the dataset path differs.

The important part is here: _URL / _URLS point to where the dataset files are loaded from.

# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""SQUAD: The Stanford Question Answering Dataset."""

import json

import datasets
from datasets.tasks import QuestionAnsweringExtractive

logger = datasets.logging.get_logger(__name__)

_CITATION = """\
@article{2016arXiv160605250R,
       author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
                 Konstantin and {Liang}, Percy},
        title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
      journal = {arXiv e-prints},
         year = 2016,
          eid = {arXiv:1606.05250},
        pages = {arXiv:1606.05250},
archivePrefix = {arXiv},
       eprint = {1606.05250},
}
"""

_DESCRIPTION = """\
Stanford Question Answering Dataset (SQuAD) is a reading comprehension \
dataset, consisting of questions posed by crowdworkers on a set of Wikipedia \
articles, where the answer to every question is a segment of text, or span, \
from the corresponding reading passage, or the question might be unanswerable.
"""

# Local directory containing the SQuAD-1.1 json files; change this to your own location
_URL = r"E:\Project\NLP\dataset\SQuAD-1.1 datasets/"
_URLS = {
    "train": _URL + "train-v1.1.json",
    "dev": _URL + "dev-v1.1.json",
}


class SquadConfig(datasets.BuilderConfig):
    """BuilderConfig for SQUAD."""

    def __init__(self, **kwargs):
        """BuilderConfig for SQUAD.

        Args:
          **kwargs: keyword arguments forwarded to super.
        """
        super(SquadConfig, self).__init__(**kwargs)


class Squad(datasets.GeneratorBasedBuilder):
    """SQUAD: The Stanford Question Answering Dataset. Version 1.1."""

    BUILDER_CONFIGS = [
        SquadConfig(
            name="plain_text",
            version=datasets.Version("1.0.0", ""),
            description="Plain text",
        ),
    ]

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "title": datasets.Value("string"),
                    "context": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "answers": datasets.features.Sequence(
                        {
                            "text": datasets.Value("string"),
                            "answer_start": datasets.Value("int32"),
                        }
                    ),
                }
            ),
            # No default supervised_keys (as we have to pass both question
            # and context as input).
            supervised_keys=None,
            homepage="https://rajpurkar.github.io/SQuAD-explorer/",
            citation=_CITATION,
            task_templates=[
                QuestionAnsweringExtractive(
                    question_column="question", context_column="context", answers_column="answers"
                )
            ],
        )

    def _split_generators(self, dl_manager):
        downloaded_files = dl_manager.download_and_extract(_URLS)

        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
        ]

    def _generate_examples(self, filepath):
        """This function returns the examples in the raw (text) form."""
        logger.info("generating examples from = %s", filepath)
        key = 0
        with open(filepath, encoding="utf-8") as f:
            squad = json.load(f)
            for article in squad["data"]:
                title = article.get("title", "")
                for paragraph in article["paragraphs"]:
                    context = paragraph["context"]  # do not strip leading blank spaces GH-2585
                    for qa in paragraph["qas"]:
                        answer_starts = [answer["answer_start"] for answer in qa["answers"]]
                        answers = [answer["text"] for answer in qa["answers"]]
                        # Features currently used are "context", "question", and "answers".
                        # Others are extracted here for the ease of future expansions.
                        yield key, {
                            "title": title,
                            "context": context,
                            "question": qa["question"],
                            "id": qa["id"],
                            "answers": {
                                "answer_start": answer_starts,
                                "text": answers,
                            },
                        }
                        key += 1
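For the Chinese version, cmrc2018.py can reuse this builder unchanged; under that assumption only the path constants differ. The directory and file names below are placeholders for wherever you keep the SQuAD-format CMRC2018 files:

# cmrc2018.py -- identical to squad.py except for these constants
_URL = r"E:\Project\NLP\dataset\CMRC2018/"      # placeholder directory
_URLS = {
    "train": _URL + "cmrc2018_train.json",      # placeholder file names
    "dev": _URL + "cmrc2018_dev.json",
}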

2.3 QA training

QA_finetune.py

# coding=utf-8
import json
import numpy as np
import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, default_data_collator, BertForQuestionAnswering, TrainingArguments, Trainer, \
    BertTokenizer
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1

squad = load_dataset('./squad.py')
squad2 = squad["validation"][0]
# Get the validation set
# valid_datasets = squad["validation"].flatten().data
# # Get the contexts in the validation set
# valid_contents = valid_datasets[2]
# # Get the gold answers in the validation set
# valid_answers = valid_datasets[4]
# # Get the gold answer starts in the validation set
# valid_answers_start = valid_datasets[5]
# Quick check that compute_f1 behaves as expected
xx = compute_f1("left Graz and severed all relations with his family", "left Graz and severed")
print(squad)

tokenizer = AutoTokenizer.from_pretrained(r"E:\Project\NLP\bert-base-uncased")


def preprocess_function(examples):
    '''
    Preprocess the training set; every question has exactly one reference answer.
    '''
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=512,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []
    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)
        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1
        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)
# train_x = squad["train"].map(preprocess_function, batched=True, remove_columns=squad["train"].column_names)
# valid_x = squad["validation"].map(preprocess_function, batched=True, remove_columns=squad["validation"].column_names)
data_collator = default_data_collator
model = BertForQuestionAnswering.from_pretrained(r"E:\Project\NLP\long-document\bert-base-uncased")
training_args = TrainingArguments(
    # fp16=True,
    output_dir="./QA_results",
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    # eval_steps=2,
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_dir="logs",
    logging_strategy="steps",
    save_total_limit=3,
    logging_steps=1,
    num_train_epochs=4,
    weight_decay=0.01,
    gradient_accumulation_steps=8,
)


def compute_metrics(eval_pred):
    predictions, label_ids = eval_pred
    start = predictions[0]
    end = predictions[1]
    answer_start = np.argmax(predictions[0], axis=1)
    answer_end = np.argmax(predictions[1], axis=1)
    label_start = label_ids[0]  # token-level start/end positions, counted in tokens, not characters
    label_end = label_ids[1]
    data = tokenized_squad["validation"]
    gold_answers = []
    pred_answers = []
    # Iterate over every validation example
    for idx, example in enumerate(data):
        input_ids = example["input_ids"]          # tokenized text
        label_start = example["start_positions"]  # gold start
        label_end = example["end_positions"]      # gold end
        gold_answer = ""
        pred_answer = ""
        for i in range(label_end - label_start + 1):
            gold_answer += str(input_ids[label_start + i]) + " "
        if answer_start[idx] > answer_end[idx]:
            # If the predicted end comes before the predicted start, collapse the span to one token
            answer_end[idx] = answer_start[idx]
        for i in range(answer_end[idx] - answer_start[idx] + 1):
            pred_answer += str(input_ids[answer_start[idx] + i]) + " "
        gold_answers.append(gold_answer.strip())
        pred_answers.append(pred_answer.strip())
    # Compute the f1 score and exact score
    f1_score = 0
    exact_score = 0
    for gold_answer, pred_answer in zip(gold_answers, pred_answers):
        f1_score += compute_f1(gold_answer, pred_answer)
        exact_score += compute_exact(gold_answer, pred_answer)
    f1_score /= len(gold_answers)
    exact_score /= len(gold_answers)
    f1_score *= 100
    exact_score *= 100
    result = {'f1_score': f1_score, 'exact_score': exact_score}
    return result


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()
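A quick way to sanity-check the fine-tuned QA checkpoint is the transformers question-answering pipeline. A minimal sketch; the checkpoint path is illustrative, point it at whatever the Trainer saved under ./QA_results:

from transformers import pipeline

qa = pipeline("question-answering",
              model="./QA_results/checkpoint-1000",      # illustrative checkpoint path
              tokenizer="./QA_results/checkpoint-1000")
out = qa(question="Which NFL team won Super Bowl 50?",
         context="Super Bowl 50 was an American football game in which the Denver Broncos "
                 "defeated the Carolina Panthers to win their third championship.")
print(out["answer"], out["score"])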

2.4 QG training

QG_finetune.py

import json
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer, BartTokenizer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, \
    BartForConditionalGeneration, Seq2SeqTrainer
from transformers.data.metrics.squad_metrics import compute_f1

# x = ['Which NFL team represented the AFC at Super Bowl 50?', 'Which NFL team represented the NFC at Super Bowl 50?', 'Where did Super Bowl 50 take place?', 'Which NFL team won Super Bowl 50?', 'What color was used to emphasize the 50th anniversary of the Super Bowl?', 'What was the theme of Super Bowl 50?', 'What day was the game played on?', 'What is the AFC short for?', 'What was the theme of Super Bowl 50?', 'What does AFC stand for?', 'What day was the Super Bowl played on?', 'Who won Super Bowl 50?', 'What venue did Super Bowl 50 take place in?', 'What city did Super Bowl 50 take place in?', 'If Roman numerals were used, what would Super Bowl 50 have been called?', 'Super Bowl 50 decided the NFL champion for what season?', 'What year did the Denver Broncos secure a Super Bowl title for the third time?', 'What city did Super Bowl 50 take place in?', 'What stadium did Super Bowl 50 take place in?', 'What was the final score of Super Bowl 50? ']
# xx = ['Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League', 'Super Bowl 50 was an American football game to determine the champion of the National Football League']
max_input_length = 512
max_target_length = 128
train_path = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\train-v1.1.json'
dev_path = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\dev-v1.1.json'
output_dir = r'E:\Project\NLP\dataset\SQuAD-1.1 datasets\QG_results'
tokenizer_path = r'E:\Project\NLP\bart-base-english'
model_path = r'E:\Project\NLP\bart-base-english'


def data_preprocess(path):
    with open(path, 'r', encoding='utf-8') as f_train:
        train_set = json.load(f_train)
    datas = train_set
    # convert
    new_data = []
    for data in datas["data"]:
        for d in data['paragraphs']:
            context = d['context']
            for qa in d['qas']:
                new_data.append({
                    'context': context,
                    'answers': qa['answers'],
                    'question': qa['question']
                })
    contexts = []
    labels = []
    for data in new_data:
        answer_text = data['answers'][0]['text']
        answer_len = len(answer_text)
        answer_start = data['answers'][0]['answer_start']
        # Highlight the answer span in the context with <hl> markers
        hl_context = data['context'][:answer_start] + '<hl>' + answer_text + '<hl>' + data['context'][answer_start + answer_len:]
        label = data['question']  # + '</s>'
        contexts.append(hl_context)
        labels.append(label)
    return contexts, labels


train_contexts, train_labels = data_preprocess(train_path)
dev_contexts, dev_labels = data_preprocess(dev_path)
train = {}
dev = {}
train["contexts"] = train_contexts  # [0:50]
train["labels"] = train_labels      # [0:50]
dev["contexts"] = dev_contexts      # [0:20]
dev["labels"] = dev_labels          # [0:20]
train_dataset = Dataset.from_dict(train)
train_dataset = train_dataset.shuffle(seed=42)
dev_dataset = Dataset.from_dict(dev)
dev_dataset = dev_dataset.shuffle(seed=42)
tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
special_tokens_dict = {'additional_special_tokens': ['<hl>']}
tokenizer.add_special_tokens(special_tokens_dict)


def preprocess_function(examples):
    inputs = [doc for doc in examples["contexts"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["labels"], max_length=max_target_length, truncation=True)
        # title_len_1 = tokenizer(examples["len_title_1"], max_length=max_target_length, truncation=True)
        # title_len_all = tokenizer(examples["len_title_all"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_dev_dataset = dev_dataset.map(preprocess_function, batched=True, remove_columns=dev_dataset.column_names)
batch_size = 1
args = Seq2SeqTrainingArguments(
    fp16=True,
    output_dir=output_dir,
    num_train_epochs=5,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=1,  # demo
    per_device_eval_batch_size=1,
    learning_rate=1e-04,
    warmup_steps=100,
    weight_decay=0.01,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_strategy="steps",
    logging_steps=1,
    save_total_limit=3,
    evaluation_strategy="epoch",
    generation_max_length=max_target_length,
    generation_num_beams=4,
    # remove_unused_columns=False,
)
model = BartForConditionalGeneration.from_pretrained(model_path)
model.resize_token_embeddings(len(tokenizer))
# model, list_en, list_de = create_student_by_copying_alternating_layers(model, 'trian.pth', 12, 3)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    f1_score = 0
    for label, pred in zip(decoded_labels, decoded_preds):
        f1_score += compute_f1(label, pred)
    f1_score /= len(decoded_preds)
    f1_score *= 100
    result = {'f1_score': f1_score}
    return result


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_dataset,
    # train_dataset=dataset_train,
    eval_dataset=tokenized_dev_dataset,
    # eval_dataset=dataset_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
train_result = trainer.train()
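Once QG training has finished, generating a question for a highlighted answer span looks roughly like this (a sketch; the checkpoint path is illustrative, point it at whichever checkpoint the Trainer saved):

from transformers import BartTokenizer, BartForConditionalGeneration

qg_dir = r"E:\Project\NLP\dataset\SQuAD-1.1 datasets\QG_results"  # illustrative checkpoint path
qg_tokenizer = BartTokenizer.from_pretrained(qg_dir)
qg_model = BartForConditionalGeneration.from_pretrained(qg_dir)

# The answer span is wrapped in <hl> markers, exactly as during training
context = "Super Bowl 50 was played at <hl>Levi's Stadium<hl> in the San Francisco Bay Area."
inputs = qg_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
ids = qg_model.generate(**inputs, max_length=128, num_beams=4)
print(qg_tokenizer.decode(ids[0], skip_special_tokens=True))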

2.5 Usage

Chinese works the same way, just swap the dataset and mind the format! Once both models are trained, use the QG model to generate questions (with their answer spans) from the source document, feed each question to the QA model, and compare the answer obtained from the source document with the answer obtained from the generated summary. This way no reference summary is needed!
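A minimal sketch of that comparison step, reusing the QA pipeline from 2.3 and a question produced by the QG model as in 2.4 (checkpoint paths and texts are placeholders):

from transformers import pipeline
from transformers.data.metrics.squad_metrics import compute_f1

qa = pipeline("question-answering",
              model="./QA_results/checkpoint-1000",      # illustrative path
              tokenizer="./QA_results/checkpoint-1000")

def qa_consistency(question, source, summary):
    """Answer the same QG-generated question on the source document and on the
    generated summary, then compare the two answers with token-level F1."""
    ans_source = qa(question=question, context=source)["answer"]
    ans_summary = qa(question=question, context=summary)["answer"]
    return compute_f1(ans_source, ans_summary)

# Average qa_consistency over all questions generated from the source document;
# a low average suggests the summary is not faithful to the source.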

3 BERTScore

This is a metric that sits between ROUGE and QA: the reference and generated summaries are fed through a BERT model to obtain embeddings, and the metric is the cosine similarity between them. Usage is very simple, just call the library function.

from bert_score import score

# data
cands = ['天天干家务烦死了', '难受死了啊']
refs = ['这也完全不相干啊', '真的难受死了啊']
P, R, F1 = score(cands, refs, lang="zh", verbose=True)
print(f"System level F1 score: {F1.mean():.3f}")
