当前位置:   article > 正文

利用GPT2生成莎士比亚写作风格的文本(python实现)_训练有个性的文字写作模型

训练有个性的文字写作模型

一:原理

在此仅仅是简单介绍,还需要读者对self-attention、Transformer、GPT有一定的知识储备。

原始的 transformer 论文引入了两种类型的 transformer 模块,分别是:编码器模块和解码器模块。原始 transformer 论文中的编码器模块可以接受长度不超过最大序列长度(如 512 个单词)的输入。如果序列长度小于该限制,我们就在其后填入预先定义的空白单词(如<pad>)。解码器模块,它与编码器模块在架构上有一点小差异——加入了一层使得它可以重点关注编码器输出的某一片段,也就是下图中的编码器-解码器自注意力(encoder-decoder self-attention)层。

 解码器在自注意力(self-attention)层上还有一个关键的差异:它将后面的单词掩盖掉了。但并不像 BERT 一样将它们替换成特殊定义的单词<mask>,而是在自注意力计算的时候屏蔽了来自当前计算位置右边所有单词的信息。GPT-2 使用的带掩模的自注意力(masked self-attention)模块很重要。普通的自注意力模块允许一个位置看到它右侧单词的信息(如下左图),而带掩模的自注意力模块则不允许这么做(如下右图)。

 GPT2模型就用了这种只包含编码器(decoder-only)的模块。想要运行一个训练好的 GPT-2 模型,最简单的方法就是让它自己随机工作(从技术上说,叫做生成无条件样本)。换句话说,我们也可以给它一点提示,让它说一些关于特定主题的话(即生成交互式条件样本)。

在随机情况下,我们只简单地提供一个预先定义好的起始单词(训练好的模型使用「|endoftext|」作为它的起始单词,不妨将其称为<s>),然后让它自己生成文字。此时,模型的输入只有一个单词,所以只有这个单词的路径是活跃的。单词经过层层处理,最终得到一个向量。向量可以对于词汇表的每个单词计算一个概率(词汇表是模型能「说出」的所有单词,GPT-2 的词汇表中有 50000 个单词)。在本例中,我们选择概率最高的单词「The」作为下一个单词。但有时这样会出问题——就像如果我们持续点击输入法推荐单词的第一个,它可能会陷入推荐同一个词的循环中,只有你点击第二或第三个推荐词,才能跳出这种循环。同样的,GPT-2 也有一个叫做「top-k」的参数,模型会从概率前 k 大的单词中抽样选取下一个单词。显然,在之前的情况下,top-k = 1。

接下来,我们将输出的单词添加在输入序列的尾部构建新的输入序列,让模型进行下一步的预测:

 

请注意,第二个单词的路径是当前唯一活跃的路径了。GPT-2 的每一层都保留了它们对第一个单词的解释,并且将运用这些信息处理第二个单词(具体将在下面一节对自注意力机制的讲解中详述),GPT-2 不会根据第二个单词重新解释第一个单词。 

输入数据是莎士比亚全集

二:python实现

  1. import random
  2. #导入随机函数,用于从预测结果中随机选取概率较大的前k个作为预测结果
  3. def select_top_k(predictions, k=10):
  4. predicted_index = random.choice(
  5. predictions[0, -1, :].sort(descending=True)[1][:10]).item()
  6. return predicted_index
  7. #定义函数select_top_k,从预测结果中选取概率较大的前k个
  8. # 过滤警告信息
  9. import warnings
  10. warnings.filterwarnings('ignore')
  11. #进行日志记录
  12. import logging
  13. logging.basicConfig(level=logging.INFO)
  14. import torch
  15. from transformers import GPT2Tokenizer
  16. # 载入预训练模型的分词器
  17. tokenizer = GPT2Tokenizer.from_pretrained('gpt2')#载入预训练模型的分词器
  18. # 使用 GPT2Tokenizer 对输入进行编码
  19. text = "From fairest creatures we desire increase,That thereby beauty’s rose might never die,But as the riper should by time decease,His tender heir might bear his memory:But thou contracted to thine own bright eyes,Feed’st thy light’s flame with self-substantial fuel,Making a famine where abundance lies,Thyself thy foe, to thy sweet self too cruel:hou that art now the world’s fresh ornament,And only herald to the gaudy spring,Within thine own bud buriest thy content,And, tender churl, mak’st waste in niggarding:Pity the world, or else this glutton be,To eat the world’s due, by the grave and thee."
  20. # with open('莎士比亚.txt','r',encoding='utf-8') as f:
  21. # text=f.read()
  22. indexed_tokens = tokenizer.encode(text)
  23. tokens_tensor = torch.tensor([indexed_tokens])
  24. tokens_tensor.shape
  25. from transformers import AutoTokenizer,AutoModelWithLMHead
  26. tokenizer=AutoTokenizer.from_pretrained('gpt2')
  27. mmodel=AutoModelWithLMHead.from_pretrained('gpt2')
  28. from transformers import GPT2LMHeadModel
  29. # 读取 GPT-2 预训练模型
  30. model = GPT2LMHeadModel.from_pretrained("./gpt2")
  31. model.eval()
  32. total_predicted_text = text
  33. n = 100 # 预测过程的循环次数
  34. for _ in range(n):
  35. with torch.no_grad():
  36. outputs = model(tokens_tensor)
  37. predictions = outputs[0]
  38. predicted_index = select_top_k(predictions, k=10)
  39. predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
  40. total_predicted_text += tokenizer.decode(predicted_index)
  41. if '<|endoftext|>' in total_predicted_text:
  42. # 如果出现文本结束标志,就结束文本生成
  43. break
  44. indexed_tokens += [predicted_index]
  45. tokens_tensor = torch.tensor([indexed_tokens])
  46. print('生成:::',total_predicted_text)

看看生成的结果:

动动您的小手,求点赞,谢谢啦!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/413194
推荐阅读
相关标签
  

闽ICP备14008679号