
PyTorch code practice: seq2seq

Loading the data
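The loader below tokenizes the English side with nltk.word_tokenize, which relies on NLTK's punkt models. If they are not already installed, a one-time download (an assumption about your environment, not part of the original script) avoids a LookupError:

import nltk
nltk.download('punkt')  # only needed once per environment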

import os
import sys
import math
import random
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk


def load_data(in_file):
    cn = []
    en = []
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split("\t")
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            # split the Chinese sentence into characters
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

UNK_IDX = 0
PAD_IDX = 1

def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2  # +2 for UNK and PAD
    word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_cn_dict = {v: k for k, v in cn_dict.items()}

def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    """Map words to indices; unknown words fall back to UNK (index 0)."""
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by English length so each batch contains similar lengths
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # reorder the Chinese sentences with the same permutation as the English ones
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]
    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

# spot-check one encoded sentence pair
k = 10000
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))

def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)  # start index of each minibatch
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches

def prepare_data(seqs):
    # pad every sequence in the batch to the length of the longest one with 0 (PAD)
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)
    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
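As a quick sanity check on the batching code, the first entry of train_data can be inspected directly. The exact numbers depend on the data, so treat this as an illustrative sketch rather than expected output:

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_y.shape)      # (batch_size, longest English sentence), (batch_size, longest Chinese sentence)
print(mb_x_len[:5], mb_y_len[:5])  # true lengths before zero-padding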

Defining the model

class PlainEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # pack_padded_sequence requires sequences sorted by length, descending
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # restore the original batch order
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]  # keep only the last layer's hidden state


class PlainDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        output = F.log_softmax(self.out(output_seq), -1)
        return output, hid


class PlainSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(PlainSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y=y, y_lengths=y_lengths, hid=hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        # greedy decoding: feed the arg-max prediction back in as the next input
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        for i in range(max_length):
            output, hid = self.decoder(y=y,
                                       y_lengths=torch.ones(batch_size).long().to(y.device),
                                       hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
        return torch.cat(preds, 1), None


# masked cross-entropy loss: padded positions are zeroed out by the mask
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # input: batch_size * seq_len * vocab_size (log-probabilities)
        input = input.contiguous().view(-1, input.size(2))
        # target, mask: batch_size * seq_len
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)
        return output
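To see what the mask does, here is a minimal sketch with made-up shapes and values (none of these tensors come from the dataset): positions where the mask is 0 contribute nothing, and the result is the average negative log-likelihood over the remaining tokens.

batch, seq_len, vocab = 2, 3, 5
log_probs = F.log_softmax(torch.randn(batch, seq_len, vocab), dim=-1)  # stand-in for decoder output
target = torch.tensor([[2, 4, 1], [3, 1, 1]])                          # 1 marks PAD positions
mask = torch.tensor([[1., 1., 0.], [1., 0., 0.]])                      # only real tokens count
crit = LanguageModelCriterion()
print(crit(log_probs, target, mask))  # scalar averaged over the three unmasked positions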

Training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100

encoder = PlainEncoder(vocab_size=en_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            # decoder input is the target without the last token,
            # decoder output is the target without the first token
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            # mask out padded positions so they do not contribute to the loss
            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss / total_num_words)

def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss / total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)

train(model, train_data, num_epochs=20)
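Training for 20 epochs can take a while, so it may be worth persisting the weights before running the translation test. This is a small optional sketch; the file name is an arbitrary choice, not something from the original script:

torch.save(model.state_dict(), "plain_seq2seq.pt")  # save after training
# later, to reload into a freshly constructed model:
model.load_state_dict(torch.load("plain_seq2seq.pt", map_location=device))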

Translation test

def translate_dev(i):
    # print the English source and the reference Chinese sentence
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])
    print(cn_sent)

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[idx] for idx in translation.data.cpu().numpy().reshape(-1)]

    # keep the predicted characters up to (but excluding) the first EOS
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))

model.eval()  # disable dropout before decoding
for i in range(100, 120):
    translate_dev(i)
    print()
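The helper above only translates sentences that are already in the dev set. A small sketch for translating an arbitrary English string reuses the same encoding and greedy-decoding steps (the function name and the example sentence are made up for illustration):

def translate_sentence(sent):
    model.eval()
    tokens = ["BOS"] + nltk.word_tokenize(sent.lower()) + ["EOS"]
    ids = [en_dict.get(w, UNK_IDX) for w in tokens]  # unknown words map to UNK
    mb_x = torch.from_numpy(np.array(ids).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(ids)])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)
    translation, _ = model.translate(mb_x, mb_x_len, bos)
    chars = [inv_cn_dict[idx] for idx in translation.data.cpu().numpy().reshape(-1)]
    out = []
    for c in chars:
        if c == "EOS":
            break
        out.append(c)
    return "".join(out)

print(translate_sentence("How are you?"))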

 
