
Intro to NLP: A Classical Chinese Poem Generator Based on Word Embedding + LSTM

A classical-poem continuation generator.

Three features are implemented:

1. Continuing a five-character poem

2. Continuing a seven-character poem

3. Writing a five-character acrostic poem

I already spent too long fiddling with this as my final project for Intro to Computer Science, so I won't rehash it here. For an overview of the content and the implementation approach, see the slides from the final presentation:

https://docs.google.com/presentation/d/1DFy3VwAETeqK0QFsokeBpDwyVkMavOjpckQKpc6XPTI/edit#slide=id.gb037c6e317_2_312

Training data source:

https://github.com/BraveY/AI-with-code/tree/master/Automatic-poem-writing

 

Data preprocessing for five-character poems (seven-character is similar; code not shown):

```python
import os
import numpy as np
from copy import deepcopy

# tang.npz bundles the corpus as index arrays plus the two vocab mappings.
all_data = np.load("tang.npz", allow_pickle=True)
dataset = all_data["data"]
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()
print(len(word2ix))    # vocabulary size
print(dataset.shape)   # (57580, 125): 57580 poems, padded to 125 tokens

# Decode every poem back to characters.
l = dataset.tolist()
poems = [[None for i in range(125)] for j in range(57580)]
for i in range(57580):
    for j in range(125):
        poems[i][j] = ix2word[l[i][j]]

# Keep only poems with exactly one <START> and one <EOP> whose body is
# exactly 48 characters (eight five-character lines plus punctuation),
# then split each body into 24-character chunks of four lines.
data = list()
for i in range(57580):
    s = 0
    e = 0
    for ele in poems[i]:
        if ele == '<START>':
            s += 1
        if ele == '<EOP>':
            e += 1
    if s == 1 and e == 1:
        st = poems[i].index('<START>')
        ed = poems[i].index('<EOP>')
        if (ed - st - 1) % 6 == 0 and poems[i][st + 6] == ',' and (ed - st - 1) == 48:
            # five-character poems, four lines per sample
            line = poems[i][st + 1:ed]
            for j in range(0, len(line), 24):
                cur = line[j:j + 24]
                if cur[5] == ',' and cur[11] == '。' and cur[17] == ',' and cur[23] == '。':
                    data.append(cur)
print(len(data))

# Strip the punctuation, leaving 20 characters per sample, and map to indices.
t = list()
for i, line in enumerate(data):
    print(i, line)
    words = line[0:5] + line[6:11] + line[12:17] + line[18:23]
    nums = [word2ix[words[i]] for i in range(len(words))]
    t.append(nums)
t = np.array(t)
print(t.shape)
t = t[77:]  # drop the first 77 samples so 29696 = 232 * 128 divides evenly into batches of 128

# Labels are the inputs shifted left by one position (next-character targets),
# with a padding 0 in the final slot.
labels = deepcopy(t)
for i in range(29696):
    for j in range(20):
        if j < 19:
            labels[i][j] = labels[i][j + 1]
        else:
            labels[i][j] = 0
np.save("train_x.npy", t)
np.save("train_y.npy", labels)
```

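As a quick sanity check (a minimal sketch, not part of the original script; it assumes tang.npz, train_x.npy, and train_y.npy sit in the working directory), you can decode one sample and confirm that each label row is its input row shifted left by one:

```python
import numpy as np

all_data = np.load("tang.npz", allow_pickle=True)
ix2word = all_data["ix2word"].item()
train_x = np.load("train_x.npy")
train_y = np.load("train_y.npy")

# The label row should read as the input row shifted left by one,
# with a padding 0 in the final slot.
print("".join(ix2word[ix] for ix in train_x[0]))
print("".join(ix2word[ix] for ix in train_y[0][:-1]))
assert (train_x[0][1:] == train_y[0][:-1]).all()
```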

Training for five-character poems (seven-character is similar; code not shown):

```python
# -*- coding: utf-8 -*-
"""54.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ZdPq71-40K5tGK__OPcUe84mfopWwJP-
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

VOCAB_SIZE = 8293

all_data = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/tang.npz", allow_pickle=True)
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()


class Model(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=3)
        self.out = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.shape
        if hidden is None:
            h_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
            c_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
        else:
            h_0, c_0 = hidden
        embed = self.embedding(input)
        # embed: (seq_len, batch_size, embedding_dim)
        output, hidden = self.lstm(embed, (h_0, c_0))
        # output: (seq_len, batch_size, hidden_dim)
        output = output.reshape(seq_len * batch_size, -1)
        # output: (seq_len * batch_size, hidden_dim)
        output = self.out(output)
        # output: (seq_len * batch_size, vocab_size)
        return output, hidden


train_x = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/train_x_54.npy")
train_y = np.load("/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/train_y_54.npy")

# Spot-check one sample: the label sequence is the input shifted left by one.
id = 253
print(train_x.shape)
print(train_y.shape)
print(train_x[id])
print(train_y[id])
a = [None for i in range(20)]
b = [None for j in range(20)]
for j in range(20):
    a[j] = ix2word[train_x[id][j]]
    b[j] = ix2word[train_y[id][j]]
b[19] = 'Null'  # the last label slot is padding
print(a)
print(b)


class PoemWriter():
    def __init__(self):
        self.model = Model(VOCAB_SIZE, 64, 128)
        self.lr = 1e-3
        self.epochs = 0
        self.seq_len = 20
        self.batch_size = 128
        self.opt = optim.Adam(self.model.parameters(), lr=self.lr)

    def train(self, epochs):
        self.epochs = epochs
        self.model = self.model.cuda()
        criterion = nn.CrossEntropyLoss()
        all_losses = []
        for epoch in range(self.epochs):
            print("Epoch:", epoch + 1)
            total_loss = 0
            for i in range(0, train_x.shape[0], self.batch_size):
                cur_x = torch.from_numpy(train_x[i:i + self.batch_size])
                cur_y = torch.from_numpy(train_y[i:i + self.batch_size])
                # The model expects (seq_len, batch_size), so transpose.
                cur_x = torch.transpose(cur_x, 0, 1).long()
                cur_y = torch.transpose(cur_y, 0, 1).long()
                # Flatten targets to match the (seq_len * batch_size, vocab) logits.
                cur_y = cur_y.reshape(self.seq_len * self.batch_size, -1).squeeze(1)
                cur_x, cur_y = cur_x.cuda(), cur_y.cuda()
                pred, _ = self.model.forward(cur_x)
                loss = criterion(pred, cur_y)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                total_loss += loss.item()
            print("Loss:", total_loss)
            all_losses.append(total_loss)
        self.model = self.model.cpu()
        plt.plot(all_losses, 'r')

    def write(self, string="空山新雨后"):
        # Seed with the five given characters, then greedily generate 15 more.
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        tmp = torch.zeros(15, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model = self.model.cuda()
        for tim in range(15):
            pred, _ = self.model.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 5] = pred[tim + 4]  # position t+5 is predicted from position t+4
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        print(out)


torch.cuda.get_device_name()

Waner = PoemWriter()
Waner.train(1000)
Waner.write()
torch.save(Waner.model, '/content/drive/My Drive/Deep Learning/LSTM - Chinese Poem Writer/model_54.pkl')

while True:
    st = input()
    if st == "-1":
        break
    Waner.write(st)
```

Final model:

```python
import numpy as np
import torch
import torch.nn as nn

VOCAB_SIZE = 8293

all_data = np.load("tang.npz", allow_pickle=True)
word2ix = all_data["word2ix"].item()
ix2word = all_data["ix2word"].item()


class Model(nn.Module):
    # Same architecture as the training scripts; the class must be defined
    # here so torch.load can unpickle the saved whole-model checkpoints.
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=3)
        self.out = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        seq_len, batch_size = input.shape
        if hidden is None:
            h_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
            c_0 = torch.zeros(3, batch_size, self.hidden_dim).cuda()
        else:
            h_0, c_0 = hidden
        embed = self.embedding(input)
        # embed: (seq_len, batch_size, embedding_dim)
        output, hidden = self.lstm(embed, (h_0, c_0))
        # output: (seq_len, batch_size, hidden_dim)
        output = output.reshape(seq_len * batch_size, -1)
        output = self.out(output)
        # output: (seq_len * batch_size, vocab_size)
        return output, hidden


class Writer():
    def __init__(self):
        self.model_74 = torch.load('model_74.pkl')  # seven-character model
        self.model_54 = torch.load('model_54.pkl')  # five-character model

    def write_74(self, string="锦瑟无端五十弦"):
        # Seed with the seven given characters, then greedily generate 21 more.
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        tmp = torch.zeros(21, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model_74 = self.model_74.cuda()
        for tim in range(21):
            pred, _ = self.model_74.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 7] = pred[tim + 6]
        ans = list()
        for i in range(28):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(28):
            out += ans[i]
            if i == 6 or i == 20:
                out += ','
            if i == 13 or i == 27:
                out += '。'
        return out

    def write_54(self, string="空山新雨后"):
        # Same scheme for five-character lines: 5 seed characters, 15 generated.
        inp = []
        for c in string:
            inp.append(word2ix[c])
        inp = torch.from_numpy(np.array(inp)).unsqueeze(1).long()
        tmp = torch.zeros(15, 1).long()
        inp = torch.cat([inp, tmp], dim=0)
        inp = inp.cuda()
        self.model_54 = self.model_54.cuda()
        for tim in range(15):
            pred, _ = self.model_54.forward(inp)
            pred = torch.argmax(pred, dim=1)
            inp[tim + 5] = pred[tim + 4]
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        return out

    def acrostic(self, string="为尔心悦"):
        # Pin the four given characters at the head of each line (positions
        # 0, 5, 10, 15) and fill the rest greedily from the previous prediction.
        inp = torch.zeros(20, 1).long()
        inp = inp.cuda()
        self.model_54 = self.model_54.cuda()
        for i in range(20):
            if i == 0 or i == 5 or i == 10 or i == 15:
                inp[i] = word2ix[string[i // 5]]
            else:
                inp[i] = pred[i - 1]
            pred, _ = self.model_54.forward(inp)
            pred = torch.argmax(pred, dim=1)
        ans = list()
        for i in range(20):
            ans.append(ix2word[inp[i].item()])
        out = ""
        for i in range(20):
            out += ans[i]
            if i == 4 or i == 14:
                out += ','
            if i == 9 or i == 19:
                out += '。'
        return out

    def task(self, string):
        # Dispatch on input length: 4 -> acrostic, 5 -> five-character,
        # 7 -> seven-character. Any other length, or an out-of-vocabulary
        # character (KeyError from word2ix), falls through to the apology.
        try:
            if len(string) == 4:
                return self.acrostic(string)
            elif len(string) == 5:
                return self.write_54(string)
            elif len(string) == 7:
                return self.write_74(string)
        except KeyError:
            pass
        return "I don't know how to write one...QAQ"
```

Test examples:
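The original screenshots are not reproduced here; as a stand-in, a minimal invocation sketch (it assumes model_54.pkl and model_74.pkl sit in the working directory and a CUDA GPU is available; the seed strings are the defaults from the code above, and actual outputs depend on the trained weights):

```python
writer = Writer()
print(writer.task("空山新雨后"))      # 5 characters -> continue a five-character poem
print(writer.task("锦瑟无端五十弦"))  # 7 characters -> continue a seven-character poem
print(writer.task("为尔心悦"))        # 4 characters -> five-character acrostic
print(writer.task("hello"))           # out-of-vocabulary -> apology string
```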

 

Rhyme is never handled explicitly, so the rhyme scheme is usually wrong, but the model does pick up some classical imagery and can evoke a certain mood. A better word embedding might yield better performance: the current embedding table is randomly initialized by PyTorch and learned from scratch during training.
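For example, here is a minimal sketch of plugging pretrained vectors into the model (the file pretrained_vectors.npy is hypothetical: a (VOCAB_SIZE, embedding_dim) array whose rows are aligned with word2ix, e.g. exported from word2vec vectors trained on a poetry corpus; this is not part of the original project):

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical file: one row per vocabulary entry, aligned with word2ix.
vectors = torch.from_numpy(np.load("pretrained_vectors.npy")).float()

model = Model(VOCAB_SIZE, vectors.shape[1], 128)
# Swap out the randomly initialized table; freeze=False lets training fine-tune it.
model.embedding = nn.Embedding.from_pretrained(vectors, freeze=False)
```

The model would still need to be trained (or fine-tuned) with these weights in place; the point is only that the embedding no longer starts from random noise.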
