当前位置:   article > 正文

Pytorch简单实现seq2seq+Attention机器人问答_pytorch seq2seq实现问答

pytorch seq2seq实现问答

一、准备数据

1.seq_example代表问题,seq_answer代表答案,数据内容如下所示:

  1. seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
  2. seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]

2.将数据进行jieba分词并加入索引index,其中SOS代表单词开头,EOS代表单词结尾,PAD补全,数据如下:

{'你': 3, '认识': 4, '我': 5, '吗': 6, '住': 7, '在': 8, '哪里': 9, '知道': 10, '的': 11, '名字': 12, '是': 13, '谁': 14, '会': 15, '唱歌': 16, '有': 17, '父母': 18, '当然': 19, '成都': 20, '不': 21, '机器人': 22, '不会': 23, '没有': 24, 'PAD': 0, 'SOS': 1, 'EOS': 2}

3. 最后将seq_example与seq_answer分词后使用索引表示

二、模型构建

1.encoder

采用双向LSTM处理输入向量,代码如下:

  1. class lstm_encoder(nn.Module):
  2. def __init__(self):
  3. super(lstm_encoder, self).__init__()
  4. # 双向LSTM
  5. self.encoder = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
  6. def forward(self, embedding_input):
  7. encoder_output, (encoder_h_n, encoder_c_n) = self.encoder(embedding_input)
  8. # 拼接前向和后向最后一个隐层
  9. encoder_h_n = torch.cat([encoder_h_n[0], encoder_h_n[1]], dim=1)
  10. encoder_c_n = torch.cat([encoder_c_n[0], encoder_c_n[1]], dim=1)
  11. return encoder_output, encoder_h_n.unsqueeze(0), encoder_c_n.unsqueeze(0)

2.decoder + Attention

decoder采用单向LSTM并加入Attention机制,即将decoder输出与encoder输出通过Attention拼接后进入全连接层做预测,Attention机制采用的General方式,具体过程如下所示:

 

 代码如下:

  1. class lstm_decoder(nn.Module):
  2. def __init__(self):
  3. super(lstm_decoder, self).__init__()
  4. # 单向LSTM
  5. self.decoder = nn.LSTM(embedding_size, n_hidden * 2, 1)
  6. # attention参数
  7. self.att_weight = nn.Linear(n_hidden * 2, n_hidden * 2)
  8. # attention_joint参数
  9. self.att_joint = nn.Linear(n_hidden * 4, n_hidden * 2)
  10. # 定义全连接层
  11. self.fc = nn.Linear(n_hidden * 2, num_classes)
  12. def forward(self, input_x, encoder_output, hn, cn):
  13. decoder_output, (decoder_h_n, decoder_c_n) = self.decoder(input_x, (hn, cn))
  14. decoder_output = decoder_output.permute(1, 0, 2)
  15. encoder_output = encoder_output.permute(1, 0, 2)
  16. decoder_output_att = self.att_weight(encoder_output)
  17. decoder_output_att = decoder_output_att.permute(0, 2, 1)
  18. # 计算分数score
  19. decoder_output_score = decoder_output.bmm(decoder_output_att)
  20. # 计算权重at
  21. at = nn.functional.softmax(decoder_output_score, dim=2)
  22. # 计算新的context向量ct
  23. ct = at.bmm(encoder_output)
  24. # 拼接ct和decoder_ht
  25. ht_joint = torch.cat((ct, decoder_output), dim=2)
  26. fc_joint = torch.tanh(self.att_joint(ht_joint))
  27. fc_out = self.fc(fc_joint)
  28. return fc_out, decoder_h_n, decoder_c_n

三、具体代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import jieba
  5. import os
  6. seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
  7. seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]
  8. # 所有词
  9. example_cut = []
  10. answer_cut = []
  11. word_all = []
  12. # 分词
  13. for i in seq_example:
  14. example_cut.append(list(jieba.cut(i)))
  15. for i in seq_answer:
  16. answer_cut.append(list(jieba.cut(i)))
  17. # 所有词
  18. for i in example_cut + answer_cut:
  19. for word in i:
  20. if word not in word_all:
  21. word_all.append(word)
  22. # 词语索引表
  23. word2index = {w: i+3 for i, w in enumerate(word_all)}
  24. # 补全
  25. word2index['PAD'] = 0
  26. # 句子开始
  27. word2index['SOS'] = 1
  28. # 句子结束
  29. word2index['EOS'] = 2
  30. index2word = {value: key for key, value in word2index.items()}
  31. # 一些参数
  32. vocab_size = len(word2index)
  33. seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
  34. embedding_size = 5
  35. num_classes = vocab_size
  36. n_hidden = 10
  37. # 将句子用索引表示
  38. def make_data(seq_list):
  39. result = []
  40. for word in seq_list:
  41. seq_index = [word2index[i] for i in word]
  42. if len(seq_index) < seq_length:
  43. seq_index += [0] * (seq_length - len(seq_index))
  44. result.append(seq_index)
  45. return result
  46. encoder_input = make_data(example_cut)
  47. decoder_input = make_data([['SOS'] + i for i in answer_cut])
  48. decoder_target = make_data([i + ['EOS'] for i in answer_cut])
  49. # 训练数据
  50. encoder_input, decoder_input, decoder_target = torch.LongTensor(encoder_input), torch.LongTensor(decoder_input), torch.LongTensor(decoder_target)
  51. # 建立encoder模型
  52. class lstm_encoder(nn.Module):
  53. def __init__(self):
  54. super(lstm_encoder, self).__init__()
  55. # 双向LSTM
  56. self.encoder = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
  57. def forward(self, embedding_input):
  58. encoder_output, (encoder_h_n, encoder_c_n) = self.encoder(embedding_input)
  59. # 拼接前向和后向最后一个隐层
  60. encoder_h_n = torch.cat([encoder_h_n[0], encoder_h_n[1]], dim=1)
  61. encoder_c_n = torch.cat([encoder_c_n[0], encoder_c_n[1]], dim=1)
  62. return encoder_output, encoder_h_n.unsqueeze(0), encoder_c_n.unsqueeze(0)
  63. # 建立attention_decoder模型
  64. class lstm_decoder(nn.Module):
  65. def __init__(self):
  66. super(lstm_decoder, self).__init__()
  67. # 单向LSTM
  68. self.decoder = nn.LSTM(embedding_size, n_hidden * 2, 1)
  69. # attention参数
  70. self.att_weight = nn.Linear(n_hidden * 2, n_hidden * 2)
  71. # attention_joint参数
  72. self.att_joint = nn.Linear(n_hidden * 4, n_hidden * 2)
  73. # 定义全连接层
  74. self.fc = nn.Linear(n_hidden * 2, num_classes)
  75. def forward(self, input_x, encoder_output, hn, cn):
  76. decoder_output, (decoder_h_n, decoder_c_n) = self.decoder(input_x, (hn, cn))
  77. decoder_output = decoder_output.permute(1, 0, 2)
  78. encoder_output = encoder_output.permute(1, 0, 2)
  79. decoder_output_att = self.att_weight(encoder_output)
  80. decoder_output_att = decoder_output_att.permute(0, 2, 1)
  81. # 计算分数score
  82. decoder_output_score = decoder_output.bmm(decoder_output_att)
  83. # 计算权重at
  84. at = nn.functional.softmax(decoder_output_score, dim=2)
  85. # 计算新的context向量ct
  86. ct = at.bmm(encoder_output)
  87. # 拼接ct和decoder_ht
  88. ht_joint = torch.cat((ct, decoder_output), dim=2)
  89. fc_joint = torch.tanh(self.att_joint(ht_joint))
  90. fc_out = self.fc(fc_joint)
  91. return fc_out, decoder_h_n, decoder_c_n
  92. class seq2seq(nn.Module):
  93. def __init__(self):
  94. super(seq2seq, self).__init__()
  95. self.word_vec = nn.Embedding(vocab_size, embedding_size)
  96. # encoder
  97. self.seq2seq_encoder = lstm_encoder()
  98. # decoder
  99. self.seq2seq_decoder = lstm_decoder()
  100. def forward(self, encoder_input, decoder_input, inference_threshold=0):
  101. embedding_encoder_input = self.word_vec(encoder_input)
  102. embedding_decoder_input = self.word_vec(decoder_input)
  103. # 调换第一维和第二维度
  104. embedding_encoder_input = embedding_encoder_input.permute(1, 0, 2)
  105. embedding_decoder_input = embedding_decoder_input.permute(1, 0, 2)
  106. # 编码器
  107. encoder_output, h_n, c_n = self.seq2seq_encoder(embedding_encoder_input)
  108. # 判断为训练还是预测
  109. if inference_threshold:
  110. # 解码器
  111. decoder_output, h_n, c_n = self.seq2seq_decoder(embedding_decoder_input, encoder_output, h_n, c_n)
  112. return decoder_output
  113. else:
  114. # 创建outputs张量存储Decoder的输出
  115. outputs = []
  116. for i in range(seq_length):
  117. decoder_output, h_n, c_n = self.seq2seq_decoder(embedding_decoder_input, encoder_output, h_n, c_n)
  118. decoder_x = torch.max(decoder_output.reshape(-1, 25), dim=1)[1].item()
  119. if decoder_x in [0, 2]:
  120. return outputs
  121. outputs.append(decoder_x)
  122. embedding_decoder_input = self.word_vec(torch.LongTensor([[decoder_x]]))
  123. embedding_decoder_input = embedding_decoder_input.permute(1, 0, 2)
  124. return outputs
  125. model = seq2seq()
  126. print(model)
  127. criterion = nn.CrossEntropyLoss()
  128. optimizer = optim.SGD(model.parameters(), lr=0.05)
  129. # 判断是否有模型文件
  130. if os.path.exists("./seq2seqModel.pkl"):
  131. model.load_state_dict(torch.load('./seq2seqModel.pkl'))
  132. else:
  133. # 训练
  134. model.train()
  135. for epoch in range(10000):
  136. pred = model(encoder_input, decoder_input, 1)
  137. loss = criterion(pred.reshape(-1, 25), decoder_target.view(-1))
  138. optimizer.zero_grad()
  139. loss.backward()
  140. optimizer.step()
  141. if (epoch + 1) % 1000 == 0:
  142. print("Epoch: %d, loss: %.5f " % (epoch + 1, loss))
  143. # 保存模型
  144. torch.save(model.state_dict(), './seq2seqModel.pkl')
  145. # 测试
  146. model.eval()
  147. question_text = '你住在哪里'
  148. question_cut = list(jieba.cut(question_text))
  149. encoder_x = make_data([question_cut])
  150. decoder_x = [[word2index['SOS']]]
  151. encoder_x, decoder_x = torch.LongTensor(encoder_x), torch.LongTensor(decoder_x)
  152. out = model(encoder_x, decoder_x)
  153. answer = ''
  154. for i in out:
  155. answer += index2word[i]
  156. print('问题:', question_text)
  157. print('回答:', answer)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/536087
推荐阅读
相关标签
  

闽ICP备14008679号