
LSTM Poem Generation with PyTorch

Generating classical Chinese poems with an LSTM

This post is a small demo of generating classical Chinese poems with an RNN, written mainly to complete a course assignment (training is slow, so the number of epochs is set to only 3; more epochs would likely give better results). Feel free to build on it. The dataset cannot be uploaded here; if you need it, send me a private message.

LSTM:

As shown in the figure above, an LSTM cell carries two state vectors: h(t) and c(t) (h(t) can be viewed as the short-term state and c(t) as the long-term state). First, the current input vector x(t) and the previous short-term state h(t-1) are fed into four different fully connected (FC) layers, each serving a different purpose:

The main layer is the one that outputs g(t): its role is to analyze the current input x(t) and the previous (short-term) state h(t-1) to produce the candidate information for this time step.

The forget gate (controlled by f(t)) decides which parts of the long-term state should be erased.

The input gate (controlled by i(t)) decides which parts of g(t) should be added to the long-term state.

The output gate (controlled by o(t)) decides which parts of the long-term state should be read at this time step and emitted as h(t) and y(t).

As shown in Figure 1, the LSTM cell uses three sigmoid activation functions and one tanh activation function:

Tanh helps regulate the values flowing through the network, keeping them within the range -1 to 1.

The sigmoid activation is similar to tanh, except that it squashes values into the range 0 to 1. This makes it suitable for updating or forgetting information: think of it as a proportion. Multiplying by 0 removes that piece of information entirely, while multiplying by 1 preserves it intact. Since memory capacity is limited, the network learns to remember what matters and forget what does not.

Example: take the input gate. From the input x(t) and the previous (short-term) state h(t-1), the cell produces the candidate information vector g(t) = (g1(t), g2(t), ..., gn(t)) (where n is the number of neurons and each gi(t) lies in (-1, 1)). It is then multiplied element-wise with i(t) = (i1(t), i2(t), ..., in(t)) (each ii(t) lies in (0, 1)), giving (g1(t)*i1(t), g2(t)*i2(t), ..., gn(t)*in(t)), i.e. the useful information from this time step, which is then added into the long-term memory c(t-1) for storage.

The key idea of the LSTM is that the network can learn what to store in the long-term state, what to discard, and what to read from it. As the long-term state c(t-1) travels through the network from left to right, it first passes through a forget gate, dropping some memories, and then an addition operation adds new memories (those selected by the input gate). The result c(t) is sent out directly, with no further transformation, so at every time step some memories are dropped and some are added. In addition, after the addition operation, the long-term state is copied, passed through a tanh function, and then filtered by the output gate. This produces the short-term state h(t) (which equals the cell's output y(t) at this time step).
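To make the gate arithmetic concrete, here is a minimal sketch of one LSTM time step written directly from the description above. The function name lstm_step and the weight matrices W_g, W_i, W_f, W_o are made up for illustration only; the actual model later in this post simply uses PyTorch's nn.LSTM.

import torch

def lstm_step(x_t, h_prev, c_prev, W_g, W_i, W_f, W_o, b_g, b_i, b_f, b_o):
    # the four FC layers all see x(t) concatenated with h(t-1)
    z = torch.cat([x_t, h_prev], dim=-1)
    g = torch.tanh(z @ W_g + b_g)       # candidate information g(t), values in (-1, 1)
    i = torch.sigmoid(z @ W_i + b_i)    # input gate i(t), values in (0, 1)
    f = torch.sigmoid(z @ W_f + b_f)    # forget gate f(t)
    o = torch.sigmoid(z @ W_o + b_o)    # output gate o(t)
    c_t = f * c_prev + i * g            # drop some old memories, add the selected new ones
    h_t = o * torch.tanh(c_t)           # short-term state h(t), which is also the output y(t)
    return h_t, c_t

# tiny usage example: input size 3, hidden size 4
x_t, h_prev, c_prev = torch.randn(3), torch.zeros(4), torch.zeros(4)
Ws = [torch.randn(7, 4) * 0.1 for _ in range(4)]   # each FC layer maps (3 + 4) -> 4
bs = [torch.zeros(4) for _ in range(4)]
h_t, c_t = lstm_step(x_t, h_prev, c_prev, *Ws, *bs)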

How it works:

This post uses an LSTM to generate poems, so how is an RNN used for text generation? The idea is simple: embed each character and use the character that follows it as its label, then feed these (input, label) pairs into the RNN until it fits the data. To generate, start from a seed character, predict the next character, append it to the sequence, and feed the extended sequence back in to predict the one after that, repeating until the poem is complete. This loop is how an RNN generates text.
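As a concrete illustration of this input/label pairing (the example poem line and variable contents here are made up and independent of the code below):

# Illustrative only: the label of every character is simply the character that follows it.
poem = "G床前明月光，疑是地上霜。E"                 # hypothetical line with start/end tokens G and E
word_int_map = {w: i for i, w in enumerate(sorted(set(poem)))}
x = [word_int_map[w] for w in poem[:-1]]        # input: every character except the last
y = [word_int_map[w] for w in poem[1:]]         # label: the same sequence shifted left by one
print(list(zip(x, y)))                          # each pair means "given this character, predict the next"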
 

main.py

import numpy as np
import collections
import torch
from torch.autograd import Variable
import torch.optim as optim
import rnn

start_token = 'G'
end_token = 'E'
batch_size = 64


def process_poems1(file_name):
    """
    :param file_name:
    :return: poems_vector has two dimensions: the first indexes the poem, the second the word index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]
    """
    poems = []
    i = 1
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f.readlines():
            try:
                i = i + 1
                title, content = line.strip().split(':')
                # content = content.replace(' ', '').replace(',', '').replace('。', '')
                content = content.replace(' ', '')
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                        start_token in content or end_token in content:
                    continue
                if len(content) < 5 or len(content) > 80:
                    continue
                content = start_token + content + end_token
                poems.append(content)
            except ValueError as e:
                print(line)
                print(i)
                print("error")
                pass
    # sort the poems by length
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # count how often each character occurs
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]  # concatenate all characters
    counter = collections.Counter(all_words)  # character -> frequency
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # sort by frequency, descending
    words, _ = zip(*count_pairs)  # zip(*) unpacks the (word, count) pairs into two tuples
    words = words[:len(words)] + (' ',)  # (' ',) is a one-element tuple
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]  # map each character to its index
    return poems_vector, word_int_map, words  # poem vectors, char-to-index map, vocabulary


def process_poems2(file_name):
    """
    :param file_name:
    :return: poems_vector has two dimensions: the first indexes the poem, the second the word index
    e.g. [[1,2,3,4,5,6,7,8,9,10],[9,6,3,8,5,2,7,4,1]]
    """
    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        # content = ''
        for line in f.readlines():
            try:
                line = line.strip()
                if line:
                    content = line.replace(' ', '').replace(',', '').replace('。', '')
                    if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                            start_token in content or end_token in content:
                        continue
                    if len(content) < 5 or len(content) > 80:
                        continue
                    # print(content)
                    content = start_token + content + end_token
                    poems.append(content)
                    # content = ''
            except ValueError as e:
                # print("error")
                pass
    # sort the poems by length
    poems = sorted(poems, key=lambda line: len(line))
    # print(poems)
    # count how often each character occurs
    all_words = []
    for poem in poems:
        all_words += [word for word in poem]
    counter = collections.Counter(all_words)  # character -> frequency
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # sort by frequency, descending
    words, _ = zip(*count_pairs)
    words = words[:len(words)] + (' ',)
    word_int_map = dict(zip(words, range(len(words))))
    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]
    return poems_vector, word_int_map, words


def generate_batch(batch_size, poems_vec, word_to_int):
    # build the training batches
    n_chunk = len(poems_vec) // batch_size  # e.g. 34813 poems // 100 = 348 batches
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size
        x_data = poems_vec[start_index:end_index]
        y_data = []
        for row in x_data:
            y = row[1:]
            y.append(row[-1])
            y_data.append(y)
        """
        x_data          y_data
        [6,2,4,6,9]     [2,4,6,9,9]   # text generation: the label is the input shifted left by one
        [1,4,2,8,5]     [4,2,8,5,5]
        """
        # print(x_data[0])
        # print(y_data[0])
        # exit(0)
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches


def run_training():
    # load and preprocess the dataset
    # poems_vector, word_to_int, vocabularies = process_poems2('./tangshi.txt')
    poems_vector, word_to_int, vocabularies = process_poems1('./poems.txt')
    print("finish loading data")
    BATCH_SIZE = 100
    torch.manual_seed(5)
    word_embedding = rnn.word_embedding(vocab_length=len(word_to_int) + 1, embedding_dim=100)  # e.g. 6123 x 100
    rnn_model = rnn.RNN_model(batch_sz=BATCH_SIZE, vocab_len=len(word_to_int) + 1, word_embedding=word_embedding,
                              embedding_dim=100, lstm_hidden_dim=128)
    # optimizer = optim.Adam(rnn_model.parameters(), lr=0.001)
    optimizer = optim.RMSprop(rnn_model.parameters(), lr=0.01)
    loss_fun = torch.nn.NLLLoss()
    # rnn_model.load_state_dict(torch.load('./poem_generator_rnn'))  # if you have already trained the model, you can load it here.
    for epoch in range(3):
        batches_inputs, batches_outputs = generate_batch(BATCH_SIZE, poems_vector, word_to_int)  # list of batches
        n_chunk = len(batches_inputs)
        for batch in range(n_chunk):
            batch_x = batches_inputs[batch]
            batch_y = batches_outputs[batch]  # (batch, time_step)
            loss = 0
            for index in range(BATCH_SIZE):
                x = np.array(batch_x[index], dtype=np.int64)
                y = np.array(batch_y[index], dtype=np.int64)
                x = Variable(torch.from_numpy(np.expand_dims(x, axis=1)))  # np.expand_dims adds a dimension: x.shape = (seq_len, 1)
                y = Variable(torch.from_numpy(y))
                pre = rnn_model(x)  # (seq_len, vocab_size), e.g. 7 x 6125
                loss += loss_fun(pre, y)
                if index == 0:
                    _, pre = torch.max(pre, dim=1)  # pre is a tensor; tolist() converts it to a list
                    print('prediction', pre.data.tolist())  # the next three lines print the prediction and the target;
                    print('b_y       ', y.data.tolist())    # you can take a screenshot and paste it into the homework report.
                    print('*' * 30)
            loss = loss / BATCH_SIZE
            print("epoch ", epoch, 'batch number', batch, "loss is: ", loss.data.tolist())
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), 1)  # gradient clipping guards against exploding gradients
            optimizer.step()  # update the parameters
            if batch % 20 == 0:
                torch.save(rnn_model.state_dict(), './poem_generator_rnn')
                print("finish save model")


def to_word(predict, vocabs):  # convert the prediction into a character
    sample = np.argmax(predict)
    if sample >= len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]


def pretty_print_poem(poem):  # print the generated poem in a tidier form
    shige = []
    for w in poem:
        if w == start_token or w == end_token:
            break
        shige.append(w)
    poem_sentences = ''.join(shige).split('。')  # split on the Chinese full stop, after dropping the start/end tokens
    for s in poem_sentences:
        if s != '' and len(s) > 2:
            print(s + '。')


def gen_poem(begin_word):
    # poems_vector, word_int_map, vocabularies = process_poems2('./tangshi.txt')  # use the other dataset to train the network
    poems_vector, word_int_map, vocabularies = process_poems1('./poems.txt')
    word_embedding = rnn.word_embedding(vocab_length=len(word_int_map) + 1, embedding_dim=100)
    rnn_model = rnn.RNN_model(batch_sz=64, vocab_len=len(word_int_map) + 1, word_embedding=word_embedding,
                              embedding_dim=100, lstm_hidden_dim=128)
    rnn_model.load_state_dict(torch.load('./poem_generator_rnn'))
    # seed the poem with the given starting character
    poem = begin_word
    word = begin_word
    while word != end_token:
        input = np.array([word_int_map[w] for w in poem], dtype=np.int64)
        input = Variable(torch.from_numpy(input))
        output = rnn_model(input, is_test=True)
        word = to_word(output.data.tolist(), vocabularies)
        poem += word
        if len(poem) > 30:
            break
    return poem


# run_training()  # uncomment this line to train the network (training takes a long time); comment it out again when only generating.
pretty_print_poem(gen_poem("日"))
pretty_print_poem(gen_poem("红"))
pretty_print_poem(gen_poem("山"))
pretty_print_poem(gen_poem("夜"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("湖"))
pretty_print_poem(gen_poem("君"))

rnn.py

import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np


def weights_init(m):
    classname = m.__class__.__name__  # obtain the class name
    if classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())  # e.g. 6125 x 128
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
        print("initial linear weight ")


class word_embedding(nn.Module):
    def __init__(self, vocab_length, embedding_dim):
        super(word_embedding, self).__init__()
        w_embeding_random_intial = np.random.uniform(-1, 1, size=(vocab_length, embedding_dim))  # uniformly distributed initial weights
        self.word_embedding = nn.Embedding(vocab_length, embedding_dim)  # the embedding layer
        self.word_embedding.weight.data.copy_(torch.from_numpy(w_embeding_random_intial))

    def forward(self, input_sentence):
        """
        :param input_sentence: a tensor containing several word indices.
        :return: a tensor containing the corresponding word embeddings
        """
        sen_embed = self.word_embedding(input_sentence)
        return sen_embed


class RNN_model(nn.Module):
    def __init__(self, batch_sz, vocab_len, word_embedding, embedding_dim, lstm_hidden_dim):
        super(RNN_model, self).__init__()
        self.word_embedding_lookup = word_embedding
        self.batch_size = batch_sz
        self.vocab_length = vocab_len
        self.word_embedding_dim = embedding_dim
        self.lstm_dim = lstm_hidden_dim
        #########################################
        # here you need to define "self.rnn_lstm": the input size is "embedding_dim", the output size is "lstm_hidden_dim",
        # the lstm has two layers, and the input and output tensors are provided as (batch, seq, feature)
        self.rnn_lstm = nn.LSTM(input_size=embedding_dim, hidden_size=lstm_hidden_dim, num_layers=2, batch_first=True)
        ##########################################
        self.fc = nn.Linear(lstm_hidden_dim, vocab_len)
        self.apply(weights_init)  # call the weight initialization function.
        self.softmax = nn.LogSoftmax(dim=1)  # log-softmax over the vocabulary, to pair with NLLLoss.
        # self.tanh = nn.Tanh()

    def forward(self, sentence, is_test=False):
        batch_input = self.word_embedding_lookup(sentence).view(1, -1, self.word_embedding_dim)  # (seq_len, 1) indices -> (1, seq_len, embedding_dim)
        # print(batch_input.size())  # print the size of the input
        ################################################
        # here you need to feed "batch_input" into the self.rnn_lstm defined above;
        # the hidden output is named "output", and the initial hidden and cell states default to zero.
        output, _ = self.rnn_lstm(batch_input)  # (1, seq_len, lstm_hidden_dim)
        ################################################
        out = output.contiguous().view(-1, self.lstm_dim)  # (seq_len, lstm_hidden_dim)
        out = F.relu(self.fc(out))
        out = self.softmax(out)
        if is_test:
            prediction = out[-1, :].view(1, -1)  # keep only the last time step: (1, vocab_len)
            # prediction = torch.max(out, 0)
            output = prediction
        else:
            output = out
        # print(out)
        return output
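As a quick check that the shapes in RNN_model.forward behave as expected, a small snippet like the following can be appended to the bottom of rnn.py; the vocabulary size (6125) and sequence length (7) are just assumed example values:

if __name__ == '__main__':
    # shape sanity check (illustrative values only)
    vocab_len = 6125
    emb = word_embedding(vocab_length=vocab_len, embedding_dim=100)
    model = RNN_model(batch_sz=1, vocab_len=vocab_len, word_embedding=emb,
                      embedding_dim=100, lstm_hidden_dim=128)
    dummy = torch.arange(7, dtype=torch.int64).view(7, 1)   # 7 character indices, shaped like the training input
    print(model(dummy).shape)                   # training mode: torch.Size([7, 6125]), one row of log-probs per position
    print(model(dummy, is_test=True).shape)     # test mode: torch.Size([1, 6125]), only the last position is kept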
