当前位置:   article > 正文

创新实训(12)-生成式文本摘要之T5

创新实训(12)-生成式文本摘要之T5

创新实训(12)-生成式文本摘要之T5

1.简介

T5:Text-To-Text-Transfer-Transformer的简称,是Google在2019年提出的一个新的NLP模型。它的基本思想就是Text-to-Text,即NLP的任务基本上都可以归为从文本到文本的处理过程。

T501

上图就是论文中的一个图,形象的展示了“Text-To-Text”的过程。

2. 模型

在论文中,作者做了大量的实验,最终发现还是Encoder-Decoder的模型表现最好,最终就选择了它,所以T5是一个基于TransformerEncoder-Decoder模型。

对于预训练的方式,论文作者也进行了许多实验,最终发现类似Bert的将一部分破坏掉的方式效果最好,而破坏的策略则是Replace spans方法最好,破坏的比例是15%最好,破坏的长度发现是3最好。如论文中下图所示:

model

3.数据集

作者从Common Crawl(一个公开的网页存档数据集)中清洗出了750GB的训练数据,取名为Colossal Clean Crawled Corpus,简称C4,不得不说这作者真会取名字。

然后作者基于此数据集进行了大量的实验,当数据量达到一定的规模之后,继续增大数据量,效果并没有明显的提高,但是大模型是必须的。

4. 复现

依然是使用colab

4.1 导入模块

import

4.1 定义DataSet

DataSetPyTorch的一个用于数据集加载的类,我们需要继承它,重写数据处理方法。

class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, source_len, summ_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.text
        self.ctext = self.data.ctext

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        ctext = str(self.ctext[index])
        ctext = ' '.join(ctext.split())

        text = str(self.text[index])
        text = ' '.join(text.split())

        source = self.tokenizer.batch_encode_plus([ctext], max_length= self.source_len, pad_to_max_length=True,return_tensors='pt')
        target = self.tokenizer.batch_encode_plus([text], max_length= self.summ_len, pad_to_max_length=True,return_tensors='pt')

        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/480073
推荐阅读
相关标签
  

闽ICP备14008679号