当前位置:   article > 正文

使用预训练模型自动续写文本的四种方法

使用预训练模型自动续写文本的四种方法

作者:皮皮雷 来源:投稿
编辑:学姐

这篇文章以中文通用领域文本生成为例,介绍四种常用的模型调用方法。在中文文本生成领域,huggingface上主要有以下比较热门的pytorch-based预训练模型:

本文用到了其中的uer/gpt2-chinese-cluecorpussmall和hfl/chinese-xlnet-base,它们都是在通用领域文本上训练的。

但是要注意有些模型(如CPM-Generate共有26亿参数)模型文件较大,GPU有限的情况下可能会OOM。

依赖包:transformers 4

本文使用的例句来源于豆瓣爬下的部分书评。

方法1:transformers.pipline

简介:

直接调用transformers里面的pipline。

源码及参数选择参考:

https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.pipeline

缺点:不能以batch形式生成句子,不能并行,大规模调用的时候时间复杂度较高。

  1. from transformers import pipeline
  2. #this pipline can only generate text one by one
  3. generator = pipeline(
  4.     'text-generation'
  5.     model="uer/gpt2-chinese-cluecorpussmall",  #可以直接写huggingface上的模型名,也可以写本地的模型地址
  6.     device = 1)
  7. text_inputs = ["客观、严谨、浓缩",
  8.                 "地摊文学……",
  9.                 "什么鬼玩意,",
  10.                 "豆瓣水军果然没骗我。",
  11.                 "这是一本社会新闻合集",
  12.                 "风格是有点学古龙嘛?但是不好看。"]
  13. sent_gen = generator(text_inputs, 
  14.                         max_length=100
  15.                         num_return_sequences=2,
  16.                         repetition_penalty=1.3
  17.                         top_k = 20
  18. #返回的sent_gen 形如#[[{'generated_text':"..."},{}],[{},{}]]
  19. for i in sent_gen:
  20.     print(i)

方法2:transformers中的TextGenerationPipeline类

源码及参数选择参考:

https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.TextGenerationPipeline

优点:相较方法1,可以设置batch size。

  1. from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
  2. tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
  3. model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
  4. text_generator = TextGenerationPipeline(model, tokenizer, batch_size=3, device=1)
  5. text_inputs = ["客观、严谨、浓缩",
  6.                 "地摊文学……",
  7.                 "什么鬼玩意,",
  8.                 "豆瓣水军果然没骗我。",
  9.                 "这是一本社会新闻合集",
  10.                 "风格是有点学古龙嘛?但是不好看。"]
  11. gen = text_generator(text_inputs, 
  12.                     max_length=100
  13.                     repetition_penalty=10.0
  14.                     do_sample=True
  15.                     num_beams=5,
  16.                     top_k=10)
  17. for sent in gen:
  18.     gen_seq = sent[0]["generated_text"]
  19.     print("")
  20.     print(gen_seq.replace(" ",""))

方法3:transformers通用方法,直接加载模型

源码及参数选择参考:

https://github.com/huggingface/transformers/blob/c4d4e8bdbd25d9463d41de6398940329c89b7fb6/src/transformers/generation_utils.py#L101

缺点:封装度较差,代码较为冗长。

优点:由于是transformers调用模型的通用写法,和其他模型(如bert)的调用方式相似,(如tokenizer的使用),可以举一反三。

  1. from transformers import AutoTokenizer, AutoModelWithLMHead
  2. import torch, os
  3. os.environ["CUDA_VISIBLE_DEVICES"= "2"
  4. tokenizer = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
  5. model = AutoModelWithLMHead.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
  6. config=model.config
  7. print(config)
  8. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  9. model = model.to(device)
  10. texts = ["客观、严谨、浓缩",
  11.                 "地摊文学……",
  12.                 "什么鬼玩意,",
  13.                 "豆瓣水军果然没骗我。",
  14.                 "这是一本社会新闻合集",
  15.                 "风格是有点学古龙嘛?但是不好看。"]
  16. #用batch输入的时候一定要设置padding
  17. encoding = tokenizer(texts, return_tensors='pt', padding=True).to(device)
  18. with torch.no_grad():
  19.     generated_ids = model.generate(**encoding, 
  20.     max_length=200
  21.     do_sample=True, #default = False
  22.     top_k=20, #default = 50
  23.     repetition_penalty=3.0 #default = 1.0use float
  24.     ) 
  25. generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
  26. for l in generated_texts:
  27.     print(l)

方法4:Simple Transformers

简介:Simple Transformers基于HuggingFace的Transformers,对特定的NLP经典任务做了高度的封装。在参数的设置上也较为灵活,可以通过词典传入参数。模型的定义和训练过程非常直观,方便理解整个AI模型的流程,很适合NLP新手使用。

simple transformers 指南:

https://simpletransformers.ai/docs/language-generation-model/

优点:这个包集成了微调的代码,不仅可以直接做生成,进一步微调也非常方便。

缺点:有些中文模型不能直接输入huggingface上的模型名称进行自动下载,会报错找不到tokenizer文件,需要手动下载到本地。

$ pip install simpletransformers

下载中文生成模型到本地文件夹 models/chinese-xlnet-base

  1. from simpletransformers.language_generation import LanguageGenerationModel
  2. # import logging
  3. # logging.basicConfig(level=logging.INFO)
  4. # transformers_logger = logging.getLogger("transformers")
  5. # transformers_logger.setLevel(logging.WARNING)
  6. model = LanguageGenerationModel("xlnet", #model type
  7. "models/chinese-xlnet-base", #包含 .bin file的文件路径
  8. args={"max_length"50"repetition_penalty"1.3,"top_k":100})
  9. prompts =["客观、严谨、浓缩",
  10.                 "地摊文学……",
  11.                 "什么鬼玩意,",
  12.                 "豆瓣水军果然没骗我。",
  13.                 "这是一本社会新闻合集",
  14.                 "风格是有点学古龙嘛?但是不好看。"]
  15. for prompt in prompts:
  16.     # Generate text using the model. Verbose set to False to prevent logging generated sequences.
  17.     generated = model.generate(prompt, verbose=False)
  18.     print(generated)

观察:用gpt2-chinese-cluecorpussmall生成的文本

参数设置:

  1. max_length=100
  2. repetition_penalty=10.0
  3. do_sample=True
  4. top_k=10

注:每一段文字的开头(标蓝)是预先给定的prompt

PS:乍一看生成语句的流利度和自然度都较好,还挺像人话的;而且有些句子能够按照“书评”的方向写。但仔细看就会发现噪音较多,而且容易“自由发挥”而跑题。这就是自由文本生成的常见问题:因为过于自由而不可控。

那么如何将生成的文本限定在想要的格式或领域中呢?这就是可控文本生成的研究范围了。一个较为常见的做法是对GPT-2作增量训练,让模型熟悉当前的语境。

总结

本文列举和比较了四种使用pytorch调用生成式模型做文本生成的方式。分别是:

① transformers自带的pipline

② transformers中的TextGenerationPipeline类

③ transformers通用方法,直接加载模型

④ Simple Transformers

这些方法各有优缺点。如果需要后续微调,建议使用③或④。如果只是简单地体验生成效果,建议使用①和②,但是方法①不能以batch形式输入,速度较慢。

关注下方《学姐带你玩AI》

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/天景科技苑/article/detail/828646
推荐阅读
相关标签