当前位置:   article > 正文

使用 T5 模型完成新闻摘要任务

使用 T5 模型完成新闻摘要任务

前文

内容摘要是自然语言处理(NLP)的核心问题之一,实现这一能目标必须让模型具备语言理解和内容生成两大能力。本文使用新闻数据,通过微调 T5 模型来完成提取新闻摘要这一任务。

要实现这一任务需要使用 Seq2Seq 模型,这也是为什么选中 T5 的原因。 Text-to-Text Transfer Transformer (T5) 是一种基于 Transformer 的模型,构建在编码器-解码器架构上,在众多无监督和监督任务中进行了多任务混合预训练,其中每个任务都是文本到文本的形式。T5 在摘要、翻译等领域表现很出色。

在这里插入图片描述

数据

本文要使用的数据集是 XSum 数据集,该数据集每个样本由 BBC 文章 document 和对应的单句摘要 summary 组成。下面展示一个样本如下:

{
'document': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.', 
'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.', 
'id': '35232142'
} 
  • 1
  • 2
  • 3
  • 4
  • 5

为了和我们使用的大模型 T5 配套使用,我们处理数据的 tokenizer 也是加载自 T5 ,这样处理的结果符合模型的输入要求。 其中具体的工作主要是在 document 前面加上 "summarize: ", 然后将各个 documentsummary 处理成对应的 input_ids ,token_type_ids ,attention_mask

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

另外为了符合训练 Seq2Seq 模型的要求,我们需要一种特殊的数据处理器,它不仅可以将输入扩充到 batch 中的最长样本的长度,还可以填充标签。我们这里选用的是 DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
train_dataset = tokenized_datasets["train"].to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=True, collate_fn=data_collator)
test_dataset = tokenized_datasets["test"].to_tf_dataset(batch_size=BATCH_SIZE,  columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
generation_dataset = tokenized_datasets["test"].shuffle().select(list(range(200))).to_tf_dataset(batch_size=BATCH_SIZE,   columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
  • 1
  • 2
  • 3
  • 4

模型

为了更好评估我们的模型效果,我们会计算真实序列和预测序列之间的 ROUGE scoreROUGE(Recall-Oriented Understudy for Gisting Evaluation) 是一种用于自动评估文本摘要质量的指标。它主要用于衡量生成的摘要与真实摘要之间的相似性,是评估文本摘要生成系统性能的常用指标之一。

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_l(decoded_labels, decoded_predictions)
    result = {"RougeL": result["f1_score"]}
    return result
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

因为显存有限,所以选择了 t5-small 这个小模型,只有 6 层结构, 大约 93M 的参数,属于 t5 家族中较小的模型,但是足够胜任本次的任务。

另外为了微调模型能达到最佳的效果,我们使用全部数据来进行预训练,并进行 5 个 epoch ,batch_size 是 8 个,需要 17G 左右的显存,耗时将近 4.5 小时

Epoch 1/5
20405/20405 [==============================] - 3240s 158ms/step - loss: 2.6656 - val_loss: 2.3561 - RougeL: 0.2318
Epoch 2/5
20405/20405 [==============================] - 3200s 157ms/step - loss: 2.5085 - val_loss: 2.2873 - RougeL: 0.2393
Epoch 3/5
20405/20405 [==============================] - 3199s 157ms/step - loss: 2.4330 - val_loss: 2.2488 - RougeL: 0.2438
Epoch 4/5
20405/20405 [==============================] - 3212s 157ms/step - loss: 2.3785 - val_loss: 2.2172 - RougeL: 0.2468
Epoch 5/5
20405/20405 [==============================] - 3239s 159ms/step - loss: 2.3335 - val_loss: 2.1938 - RougeL: 0.2493
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

推理

推理的过程我们不用编写复杂的推理代码,Hugging Face Transformers 库中的 pipeline 工具已经针对各种任务提供了各种流水线工具,比如这里我们使用 summarization 工具,再结合我们经过微调的大模型 t5-small ,即可对给定的输入文档生成摘要。

summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, framework="tf")
for i in range(5):
    summarizer(raw_datasets["test"][i]["document"], min_length=MIN_TARGET_LENGTH, max_length=MAX_TARGET_LENGTH)
    print(raw_datasets["test"][i]["summary"])
  • 1
  • 2
  • 3
  • 4

推理结果和原摘要进行对比,可以看出效果一般,主要还是在于模型的体量太小了:

预测:[{'summary_text': 'The Australian government is trying to verify the deaths of two Lebanese fighters who left the country to fight with Islamic State militants.'}]
真实:The Australian government says new citizenship laws will help combat terrorism but experts worry too many people will be affected by the laws, including children.
预测:[{'summary_text': "It's been a week since the murder of Lucky the donkey, who was found dead at her home in Fermanagh, County Down."}]
真实:On the day of her burial, County Fermanagh murder victim Connie Leonard's legacy is already being felt.
预测:[{'summary_text': 'Insurance claims for catastrophic storms in Australia have risen by more than A$8.8m (£8.6m) in the past year, a council has said.'}]
真实:Severe storms that hit Australia during April and May have led to more than A$1.55bn ($1.18bn; £778m) in insurance losses so far.
预测:[{'summary_text': "It's been a long time since I was selected by BBC Sport to pick my own team of the week for BBC Sport."}]
真实:Manchester United did Chelsea's title rivals Tottenham a favour and kept up their own pursuit of the top four with a dominant win over the Premier League leaders.
预测:[{'summary_text': 'A Belarusian politician has been found guilty of tax evasion after he held bank accounts in Poland and Lithuania.'}]
真实:One of Belarus' most prominent human rights activists has been sentenced to four-and-a-half years in prison for tax evasion.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

小结

文本只是简单展示新闻摘要的实现过程,其实对于这类简单的文本摘要任务,稍微大一点的 BERT 或者 T5 这种传统的模型也是可以胜任的,如果没有更高精度的业务需求没有必要一味追求 LLM ,毕竟对于个人或者中小企业来说目前使用 LLM 的成本较高。

参考

我的代码仓库
我的原文

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/480095
推荐阅读
相关标签
  

闽ICP备14008679号