当前位置:   article > 正文

NLP笔记(8)——轻松构建Seq2Seq模型,保姆级教学!

seq2seq

图片

 一、Seq2Seq的原理

Sequence to sequence (seq2seq)是由encoder(编码器)和decoder(解码器)两个RNN的组成的。其中encoder负责对输入句子的理解,转化为context vector,decoder负责对理解后的句子的向量进行处理,解码,获得输出。上述的过程和我们大脑理解东西的过程很相似,听到一句话,理解之后,尝试组装答案,进行回答。那么此时,就有一个问题,在encoder的过程中得到的context vector作为decoder的输入,那么这样一个输入,怎么能够得到多个输出呢?其实就是当前一步的输出,作为下一个单元的输入,然后得到结果。

  1. outputs = []
  2. while True:
  3. output = decoderd(output)
  4. outputs.append(output)

在训练数据集中,可以再输出的最后面添加一个结束符<END>,如果遇到该结束符,则可以终止循环。

  1. outputs = []
  2. while output!="<END>":
  3. output = decoderd(output)
  4. outputs.append(output)

Seq2seq模型中的encoder接受一个长度为M的序列,得到1个 context vector,之后decoder把这一个context vector转化为长度为N的序列作为输出,从而构成一个M to N的模型,能够处理很多不定长输入输出的问题,比如:文本翻译,问答,文章摘要,关键字写诗等等

二、Seq2Seq模型的实现

2.1.模型需求及实现流程

需求:完成一个模型,实现往模型输入一串数字,输出这串数字+0

例如

  • 输入123456789,输出1234567890

  • 输入52555568,输出525555680

流程:

首先文本转化为序列,使用序列,准备数据集,准备Dataloader。然后完成编码器和解码器。然后完成seq2seq模型。然后完成模型训练的逻辑,进行训练。然后完成模型评估的逻辑,进行模型评估。

图片

2.2.模型的实现

1.创建配置文件(config.py)

  1. batch_size = 512
  2. max_len = 10
  3. dropout = 0
  4. embedding_dim = 100
  5. hidden_size = 64

2.文本转化为序列(word_sequence.py)

由于输入的是数字,为了把这写数字和词典中的真实数字进行对应,可以把这些数字理解为字符串。所以需要先把字符串对应为数字,然后把数字转化为字符串。

  1. class NumSequence:
  2. UNK_TAG = "UNK"
  3. PAD_TAG = "PAD"
  4. EOS_TAG = "EOS"
  5. SOS_TAG = "SOS"
  6. UNK = 0
  7. PAD = 1
  8. EOS = 2
  9. SOS = 3
  10. def __init__(self):
  11. self.dict = {
  12. self.UNK_TAG : self.UNK,
  13. self.PAD_TAG : self.PAD,
  14. self.EOS_TAG : self.EOS,
  15. self.SOS_TAG : self.SOS
  16. }
  17. for i in range(10):
  18. self.dict[str(i)] = len(self.dict)
  19. self.index2word = dict(zip(self.dict.values(),self.dict.keys()))
  20. def __len__(self):
  21. return len(self.dict)
  22. def transform(self,sequence,max_len=None,add_eos=False):
  23. sequence_list = list(str(sequence))
  24. seq_len = len(sequence_list)+1 if add_eos else len(sequence_list)
  25. if add_eos and max_len is not None:
  26. assert max_len>= seq_len, "max_len 需要大于seq+eos的长度"
  27. _sequence_index = [self.dict.get(i,self.UNK) for i in sequence_list]
  28. if add_eos:
  29. _sequence_index += [self.EOS]
  30. if max_len is not None:
  31. sequence_index = [self.PAD]*max_len
  32. sequence_index[:seq_len] = _sequence_index
  33. return sequence_index
  34. else:
  35. return _sequence_index
  36. def inverse_transform(self,sequence_index):
  37. result = []
  38. for i in sequence_index:
  39. if i==self.EOS:
  40. break
  41. result.append(self.index2word.get(int(i),self.UNK_TAG))
  42. return result
  43. # 实例化
  44. num_sequence = NumSequence()
  45. if __name__ == '__main__':
  46. num_sequence = NumSequence()
  47. print(num_sequence.dict)
  48. print(num_sequence.index2word)
  49. print(num_sequence.transform("1231230",add_eos=True))

3.数据集(dataset.py)

随机创建[0,100000000]的整型,准备数据集,运行程序可以看到大部分的数字长度为8,在目标值后面添加上0和EOS之后,最大长度为10。所以config配置文件的max_len=10。

  1. from torch.utils.data import Dataset,DataLoader
  2. import numpy as np
  3. from word_sequence import num_sequence
  4. import torch
  5. import config
  6. class RandomDataset(Dataset):
  7. def __init__(self):
  8. super(RandomDataset,self).__init__()
  9. self.total_data_size = 500000
  10. np.random.seed(10)
  11. self.total_data = np.random.randint(1,100000000,size=[self.total_data_size])
  12. def __getitem__(self, idx):
  13. input = str(self.total_data[idx])
  14. return input, input+ "0",len(input),len(input)+1
  15. def __len__(self):
  16. return self.total_data_size
  17. def collate_fn(batch):
  18. #1. 对batch进行排序,按照长度从长到短的顺序排序
  19. batch = sorted(batch,key=lambda x:x[3],reverse=True)
  20. input,target,input_length,target_length = zip(*batch)
  21. #2.进行padding的操作
  22. input = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len) for i in input])
  23. target = torch.LongTensor([num_sequence.transform(i,max_len=config.max_len,add_eos=True) for i in target])
  24. input_length = torch.LongTensor(input_length)
  25. target_length = torch.LongTensor(target_length)
  26. return input,target,input_length,target_length
  27. data_loader = DataLoader(dataset=RandomDataset(),batch_size=config.batch_size,collate_fn=collate_fn,drop_last=True)
  28. if __name__ == '__main__':
  29. data_loader = DataLoader(dataset=RandomDataset(),batch_size=config.batch_size,drop_last=True)
  30. for idx,(input,target,input_lenght,target_length) in enumerate(data_loader):
  31. print(idx) #输出
  32. print(input) #输入
  33. print(target) #输出,后面加0
  34. print(input_lenght) #输入长度
  35. print(target_length) #输出长度
  36. break

4.编码器(encoder.py)

编码器(encoder)的目的就是为了对文本进行编码,把编码后的结果交给后续的程序使用,所以在这里可以使用Embedding+GRU的结构,使用最后一个time step的输出(hidden state)作为句子的编码结果。

图片

  1. import torch.nn as nn
  2. from word_sequence import num_sequence
  3. import config
  4. class NumEncoder(nn.Module):
  5. def __init__(self):
  6. super(NumEncoder,self).__init__()
  7. self.vocab_size = len(num_sequence)
  8. self.dropout = config.dropout
  9. self.embedding_dim = config.embedding_dim
  10. self.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=self.embedding_dim,padding_idx=num_sequence.PAD)
  11. self.gru = nn.GRU(input_size=self.embedding_dim,
  12. hidden_size=config.hidden_size,
  13. num_layers=1,
  14. batch_first=True,
  15. dropout=config.dropout)
  16. def forward(self, input,input_length):
  17. embeded = self.embedding(input)
  18. embeded = nn.utils.rnn.pack_padded_sequence(embeded,lengths=input_length,batch_first=True)
  19. out,hidden = self.gru(embeded)
  20. out,outputs_length = nn.utils.rnn.pad_packed_sequence(out,batch_first=True,padding_value=num_sequence.PAD)
  21. return out,hidden

5.解码器(decoder.py)

解码器主要负责实现对编码之后结果的处理,得到预测值,为后续计算损失做准备。解码器也是一个RNN,即也可以使用LSTM or GRU的结构。

  1. import torch
  2. import torch.nn as nn
  3. import config
  4. import random
  5. import torch.nn.functional as F
  6. from word_sequence import num_sequence
  7. class NumDecoder(nn.Module):
  8. def __init__(self):
  9. super(NumDecoder,self).__init__()
  10. self.max_seq_len = config.max_len
  11. self.vocab_size = len(num_sequence)
  12. self.embedding_dim = config.embedding_dim
  13. self.dropout = config.dropout
  14. self.embedding = nn.Embedding(num_embeddings=self.vocab_size,embedding_dim=self.embedding_dim,padding_idx=num_sequence.PAD)
  15. self.gru = nn.GRU(input_size=self.embedding_dim,
  16. hidden_size=config.hidden_size,
  17. num_layers=1,
  18. batch_first=True,
  19. dropout=self.dropout)
  20. self.log_softmax = nn.LogSoftmax()
  21. self.fc = nn.Linear(config.hidden_size,self.vocab_size)
  22. def forward(self, encoder_hidden,target,target_length):
  23. # encoder_hidden [batch_size,hidden_size]
  24. # target [batch_size,seq-len]
  25. decoder_input = torch.LongTensor([[num_sequence.SOS]]*config.batch_size)
  26. # print("decoder_input size:",decoder_input.size())
  27. decoder_outputs = torch.zeros(config.batch_size,config.max_len,self.vocab_size) #[seq_len,batch_size,14]
  28. decoder_hidden = encoder_hidden #[batch_size,hidden_size]
  29. for t in range(config.max_len):
  30. decoder_output_t , decoder_hidden = self.forward_step(decoder_input,decoder_hidden)
  31. # print(decoder_output_t.size(),decoder_hidden.size())
  32. # print(decoder_outputs.size())
  33. decoder_outputs[:,t,:] = decoder_output_t
  34. use_teacher_forcing = random.random() > 0.5
  35. if use_teacher_forcing:
  36. decoder_input =target[:,t].unsqueeze(1) #[batch_size,1]
  37. else:
  38. value, index = torch.topk(decoder_output_t, 1) # index [batch_size,1]
  39. decoder_input = index
  40. # print("decoder_input size:",decoder_input.size(),use_teacher_forcing)
  41. return decoder_outputs,decoder_hidden
  42. def forward_step(self,decoder_input,decoder_hidden):
  43. """
  44. :param decoder_input:[batch_size,1]
  45. :param decoder_hidden: [1,batch_size,hidden_size]
  46. :return: out:[batch_size,vocab_size],decoder_hidden:[1,batch_size,didden_size]
  47. """
  48. embeded = self.embedding(decoder_input) #embeded: [batch_size,1 , embedding_dim]
  49. # print("forworad step embeded:",embeded.size())
  50. out,decoder_hidden = self.gru(embeded,decoder_hidden) #out [1, batch_size, hidden_size]
  51. # print("forward_step out size:",out.size()) #[1, batch_size, hidden_size]
  52. out = out.squeeze(0)
  53. out = F.log_softmax(self.fc(out),dim=-1)#[batch_Size, vocab_size]
  54. out = out.squeeze(1)
  55. # print("out size:",out.size(),decoder_hidden.size())
  56. return out,decoder_hidden
  57. def evaluation(self,encoder_hidden): #[1, 20, 14]
  58. # target = target.transpose(0, 1) # batch_first = False
  59. batch_size = encoder_hidden.size(1)
  60. decoder_input = torch.LongTensor([[num_sequence.SOS] * batch_size])
  61. # print("decoder start input size:",decoder_input.size()) #[1, 20]
  62. decoder_outputs = torch.zeros(batch_size,config.max_len, self.vocab_size) # [seq_len,batch_size,14]
  63. decoder_hidden = encoder_hidden
  64. for t in range(config.max_len):
  65. decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
  66. decoder_outputs[:,t,:] = decoder_output_t
  67. value, index = torch.topk(decoder_output_t, 1) # index [20,1]
  68. decoder_input = index.transpose(0, 1)
  69. # print("decoder_outputs size:",decoder_outputs.size())
  70. # # 获取输出的id
  71. decoder_indices =[]
  72. # decoder_outputs = decoder_outputs.transpose(0,1) #[batch_size,seq_len,vocab_size]
  73. # print("decoder_outputs size",decoder_outputs.size())
  74. for i in range(decoder_outputs.size(1)):
  75. value,indices = torch.topk(decoder_outputs[:,i,:],1)
  76. # print("indices size",indices.size(),indices)
  77. # indices = indices.transpose(0,1)
  78. decoder_indices.append(int(indices[0][0].data))
  79. return decoder_indices

 6.完成seq2seq模型(seq2seq.py)

  1. import torch
  2. import torch.nn as nn
  3. class Seq2Seq(nn.Module):
  4. def __init__(self,encoder,decoder):
  5. super(Seq2Seq,self).__init__()
  6. self.encoder = encoder
  7. self.decoder = decoder
  8. def forward(self, input,target,input_length,target_length):
  9. encoder_outputs,encoder_hidden = self.encoder(input,input_length)
  10. decoder_outputs,decoder_hidden = self.decoder(encoder_hidden,target,target_length)
  11. return decoder_outputs,decoder_hidden
  12. def evaluation(self,inputs,input_length):
  13. encoder_outputs,encoder_hidden = self.encoder(inputs,input_length)
  14. decoded_sentence = self.decoder.evaluation(encoder_hidden)
  15. return decoded_sentence

7.完成训练

  1. import torch
  2. import config
  3. from torch import optim
  4. import torch.nn as nn
  5. from encoder import NumEncoder
  6. from decoder import NumDecoder
  7. from seq2seq import Seq2Seq
  8. from dataset import data_loader as train_dataloader
  9. from word_sequence import num_sequence
  10. from tqdm import tqdm
  11. encoder = NumEncoder()
  12. decoder = NumDecoder()
  13. model = Seq2Seq(encoder,decoder)
  14. for name, param in model.named_parameters():
  15. if 'bias' in name:
  16. torch.nn.init.constant_(param, 0.0)
  17. elif 'weight' in name:
  18. torch.nn.init.xavier_normal_(param)
  19. optimizer = optim.Adam(model.parameters())
  20. criterion= nn.NLLLoss(ignore_index=num_sequence.PAD,reduction="mean")
  21. def get_loss(decoder_outputs,target):
  22. target = target.view(-1) #[batch_size*max_len]
  23. decoder_outputs = decoder_outputs.view(config.batch_size*config.max_len,-1)
  24. return criterion(decoder_outputs,target)
  25. def train(epoch):
  26. total_loss = 0
  27. correct = 0
  28. total = 0
  29. progress_bar = tqdm(total=len(train_dataloader), desc='Train Epoch {}'.format(epoch), unit='batch')
  30. for idx, (input, target, input_length, target_len) in enumerate(train_dataloader):
  31. optimizer.zero_grad()
  32. ##[seq_len,batch_size,vocab_size] [batch_size,seq_len]
  33. decoder_outputs, decoder_hidden = model(input, target, input_length, target_len)
  34. loss = get_loss(decoder_outputs, target)
  35. total_loss += loss.item()
  36. loss.backward()
  37. optimizer.step()
  38. _, predicted = torch.max(decoder_outputs.data, 2)
  39. correct += (predicted == target).sum().item()
  40. total += target.size(0) * target.size(1)
  41. acc = 100 * correct / total
  42. avg_loss = total_loss / (idx + 1)
  43. progress_bar.set_postfix({'loss': avg_loss, 'acc': '{:.2f}%'.format(acc)})
  44. progress_bar.update()
  45. progress_bar.close()
  46. torch.save(model.state_dict(), "models/seq2seq_model.pkl")
  47. torch.save(optimizer.state_dict(), 'models/seq2seq_optimizer.pkl')
  48. if __name__ == '__main__':
  49. for i in range(10):
  50. train(i)

图片

8.进行评估

随机生成10000个测试集进行模型的验证,然后输入一串数字观察输出结果

  1. import torch
  2. from encoder import NumEncoder
  3. from decoder import NumDecoder
  4. from seq2seq import Seq2Seq
  5. from word_sequence import num_sequence
  6. import random
  7. encoder = NumEncoder()
  8. decoder = NumDecoder()
  9. model = Seq2Seq(encoder,decoder)
  10. model.load_state_dict(torch.load("models/seq2seq_model.pkl"))
  11. def evaluate():
  12. correct = 0
  13. total = 0
  14. for i in range(10000):
  15. test_words = random.randint(1,100000000)
  16. test_word_len = [len(str(test_words))]
  17. _test_words = torch.LongTensor([num_sequence.transform(test_words)])
  18. decoded_indices = model.evaluation(_test_words,test_word_len)
  19. result = num_sequence.inverse_transform(decoded_indices)
  20. if str(test_words)+"0" == "".join(result):
  21. correct += 1
  22. total += 1
  23. accuracy = correct/total
  24. print("10000个测试集的Acc: ", accuracy)
  25. def predict():
  26. test_word = input("Enter a number to predict: ")
  27. test_word_len = [len(test_word)]
  28. _test_word = torch.LongTensor([num_sequence.transform(int(test_word))])
  29. decoded_indices = model.evaluation(_test_word,test_word_len)
  30. result = num_sequence.inverse_transform(decoded_indices)
  31. print("Prediction: ", "".join(result))
  32. if __name__ == '__main__':
  33. evaluate()
  34. predict()

图片

 按照上面步骤一步步进行操作,成功运行是没有问题的,如果想直接获取源代码进行研究,可以关注下面公众号联系~会不定期发布相关设计内容包括但不限于如下内容:信号处理、通信仿真、算法设计、matlab appdesigner,gui设计、simulink仿真......希望能帮到你!

5a8015ddde1e41418a38e958eb12ecbd.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/506814
推荐阅读
相关标签
  

闽ICP备14008679号