
Fine-tuning T5 to Build an English Summarization Model


Introduction

T5 (Text-to-Text Transfer Transformer) is a Transformer-based pre-trained model proposed by the Google Brain team in 2019. It performs strongly on natural language processing (NLP) tasks and has achieved leading results on several public benchmarks. Compared with earlier pre-trained models such as BERT and GPT, T5 offers a more flexible framework. Its core idea is "text-to-text transfer": every NLP task is cast as a text-to-text problem, so all tasks can be handled and trained in a single unified way.

T5 Task Prefixes

The prefix fed to a T5 model can be customized for the task at hand. Here are prefixes for some common tasks:

1. Summarization task prefix:

"summarize: " — marks the text to be summarized

2. Translation task prefixes:

"translate English to French: " — translate English into French

"translate English to Spanish: " — translate English into Spanish

In code, the prefix is simply prepended to the input before tokenizing with the T5Tokenizer from the Transformers library. The code in the following sections uses summarization as the example; other tasks work the same way, as the short translation sketch below illustrates.
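As a minimal sketch of swapping in a different prefix (not part of the original article's code), the snippet below uses the translation prefix. It assumes the same local t5-small checkpoint directory, model/t5_small, that is used in the inference section below.

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Assumed local path to the t5-small checkpoint (same as in the inference section)
cache_dir = 'model/t5_small'
tokenizer = T5Tokenizer.from_pretrained(cache_dir)
model = T5ForConditionalGeneration.from_pretrained(cache_dir)

# Prepend the task prefix, then generate as usual
text = "The Eiffel Tower is one of the most famous landmarks in Paris."
input_ids = tokenizer("translate English to French: " + text, return_tensors="pt").input_ids
outputs = model.generate(input_ids, num_beams=4, max_length=64, early_stopping=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))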

Model Inference Code

Here we use the t5-small pre-trained model from Hugging Face; it can be downloaded from t5-small · Hugging Face.

cache_dir specifies the directory where the model and tokenizer files are cached. If they have already been downloaded, they are read directly from this directory instead of being downloaded again. Given the network restrictions in mainland China, it is best to download the model files in advance.

T5Tokenizer.from_pretrained loads the tokenizer that goes with the pre-trained model and returns a tokenizer object. It is used in the same way as other Hugging Face tokenizers; here it is the tokenizer of the t5-small model.

T5ForConditionalGeneration.from_pretrained loads the pre-trained T5 model and returns a model object that can be used for inference and generation.

from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline

# Load the T5 model and tokenizer
cache_dir = 'model/t5_small'
tokenizer = T5Tokenizer.from_pretrained(cache_dir)
model = T5ForConditionalGeneration.from_pretrained(cache_dir)
print(model.config)  # print the model configuration
print(model)         # print the model architecture

text = '''The Eiffel Tower is one of the most famous landmarks in Paris, France, and one of the most famous buildings in the world. It was designed and built by the famous French engineer Gustave Eiffel. The Paris Tower was built for the 1889 Paris World Exposition to celebrate the 100th anniversary of the French Revolution.
The Paris Tower is located on Champs de Mars in the seventh arrondissement of Paris, France. It is a steel structure tower that reaches a height of 324 meters (approximately 1063 feet). At its completion, it was the world's tallest man-made building until it was surpassed by the Chrysler Building in New York in 1930.
The iron tower is divided into three observation platforms: the first layer, the second layer, and the third layer. You can take the elevator or climb the stairs to different observation platforms and enjoy the magnificent city scenery of Paris. The top level observation deck is the most popular, where tourists can overlook the entire city of Paris, including museums, churches, and the famous Seine River around the Eiffel Tower'''

# Prepend the summarization prefix before tokenizing
input_ids = tokenizer("summarize: " + text, return_tensors="pt").input_ids
outputs = model.generate(input_ids, num_beams=4, max_length=128, early_stopping=True)
summary_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(summary_text)
# output: The Eiffel Tower was built for the 1889 Paris World Exposition to celebrate the 100th anniversary of the French Revolution. It is a steel structure tower that reaches a height of 324 meters (approximately 1063 feet) at its completion, it was the world's tallest man-made building until it was surpassed by the Chrysler Building in New York in 1930.

# Alternatively, use a pipeline (the model loaded above is a PyTorch model, so framework="pt")
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, framework="pt")
result = summarizer(
    text,
    min_length=5,
    max_length=128,
)
print(result)

Model Fine-tuning Code

Below we fine-tune the Hugging Face T5 model for text summarization on the XSum dataset, which can be downloaded with the datasets library.

# Download the XSum dataset
from datasets import load_dataset
raw_datasets = load_dataset("xsum")
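Each XSum example pairs a source article with a single-sentence reference summary. A quick way to inspect the splits and fields (assuming the standard xsum layout with "document", "summary", and "id" columns):

print(raw_datasets)                  # train / validation / test splits
sample = raw_datasets["train"][0]
print(sample["document"][:200])      # the source article
print(sample["summary"])             # the reference summary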

Training is evaluated with ROUGE, which measures the similarity between a generated summary and its reference summary.

The metric can be downloaded with the following code:

from datasets import load_metric
metric = load_metric('rouge')
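As a quick sketch of what this metric returns: each key (rouge1, rouge2, rougeL, ...) maps to an aggregate score whose mid.fmeasure field is what the training code below reads.

result = metric.compute(
    predictions=["the tower was built in 1889"],
    references=["the eiffel tower was built for the 1889 world exposition"],
    use_stemmer=True,
)
print(result["rouge1"].mid.fmeasure)  # unigram-overlap F1, in [0, 1]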

Because I had already downloaded it, I added a cache path in the code.

The main fine-tuning code is below; note that the hyperparameters have not been carefully tuned.

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer
import torch
import numpy as np
import nltk  # needed for sentence splitting in compute_metrics
from datasets import load_metric

cache_dir = 'model/t5_small'
tokenizer = T5Tokenizer.from_pretrained(cache_dir)
model = T5ForConditionalGeneration.from_pretrained(cache_dir)
print(model.config)  # print the model configuration
print(model)         # print the model architecture
device = torch.device('cuda')
model.to(device)

# Hyperparameters
MAX_INPUT_LENGTH = 1024   # maximum input length for the model
MIN_TARGET_LENGTH = 5     # minimum output length
MAX_TARGET_LENGTH = 64    # maximum output length

# Fine-tuning
batch_size = 16
args = Seq2SeqTrainingArguments(  # the AdamW optimizer is used by default
    output_dir="model/t5_small/test-summarization",
    evaluation_strategy="epoch",  # run evaluation after every epoch
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=10,
    predict_with_generate=True,  # must be True to compute generation metrics
    fp16=True,
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Evaluation function
metric = load_metric('rouge', cache_dir='model/metrics/rouge1')
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects one sentence per line
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

# Training (train_dataset / val_dataset are the preprocessed XSum splits, see the note below)
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics  # pass the metric function
)
trainer.train()

# Save the model
save_dir = "model/t5/trained_t5"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

The example code does not show the data-preprocessing step: prepare train_dataset and val_dataset as needed and feed them to the trainer, and do not forget to add the task prefix. A minimal preprocessing sketch is given below.
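The following sketch is not the article's original preprocessing; it assumes the standard XSum fields "document" and "summary" and reuses the length constants defined above. Newer versions of transformers also accept a text_target argument in place of as_target_tokenizer().

prefix = "summarize: "

def preprocess_function(examples):
    # Prepend the task prefix to every source document
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    # Tokenize the reference summaries as labels
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
train_dataset = tokenized_datasets["train"]
val_dataset = tokenized_datasets["validation"]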


                