
Transformer Code (PyTorch Implementation and Detailed Walkthrough)


This blogger explains it very well: https://www.cnblogs.com/kongen/p/18088002

Video: Transformer的PyTorch实现 (on Bilibili)

If you are new to the Transformer, I strongly recommend going through the two links above first! And before that, I hope you read the original paper carefully, at least twice!

Here I want to explain why self-attention alone is enough to predict the target sequence. At first I was also confused: all I knew was that my input a = "我 有 一只 猫" goes through the encoder and decoder and comes out as b = "I have a cat". After thinking about it, my rough understanding is this: the Multi-Head Attention inside the Encoder computes attention weights over the encoder input, i.e. how much each word in the source sequence a attends to every other word. Likewise, the first Multi-Head Attention in the Decoder computes the attention weights among the words of the target sequence. The second Multi-Head Attention in the Decoder ties the Encoder and Decoder together: it computes, for each target-side position, the weights over the source words, and those scores are finally turned into probabilities; the target word with the highest probability is the answer. Extending this, once we build a source vocabulary (src_vocab) and a target vocabulary (tgt_vocab), several rounds of training give us a fairly accurate mapping, i.e. which target word each position is most likely to translate into. My suggestion: skim the theory first, then look for the details in the code!

Now, on to the main topic: a PyTorch implementation of the Transformer. I will walk through the code part by part, focusing on the points I consider most important, so I cannot cover everything; if anything is unclear, ask in the comments and I will be happy to discuss. The complete code is at the end. Also, while implementing, I suggest creating a separate test.py to print out the data of each part and see what it looks like, especially the tensors you are unsure about!

Imports

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

Data Preprocessing and Parameter Setup

To keep things easy to follow, the model does not use a large dataset, just two German-to-English sentence pairs. Here is what the code below does:

  • Build the source and target vocabularies (a vocabulary is just a dict of the form "word": index). Here they are written by hand; a real dataset would need code to build them.
  • src_len is the source sentence length, 5; tgt_len is the target sentence length, 6.
  • 'P' is the padding token: all sentences must share one length, and the shorter ones are padded with 'P'.
  • make_data() converts each word into its index in the vocabulary and produces enc_inputs, dec_inputs and dec_outputs; be sure to remember their shapes (see the shape check after the code block below).
  • The custom MyDataSet class wraps the data, and a DataLoader generates mini-batches.
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
sentences = [
    # enc_input               dec_input            dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
# Padding Should be Zero
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
tgt_vocab_size = len(tgt_vocab)
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]   # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]   # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]  # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]
        enc_inputs.extend(enc_input)
        dec_inputs.extend(dec_input)
        dec_outputs.extend(dec_output)
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

mydataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(mydataset, 2, shuffle=True)
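
To get a feel for these tensors, print their shapes in a scratch test.py. This is just a quick check I would run, not part of the original code; the expected output (shown in the comments) assumes the two toy sentences above:

# Quick sanity check after make_data(); the rest of the code relies on these shapes
print(enc_inputs.shape)   # torch.Size([2, 5])  -> [batch_size, src_len]
print(dec_inputs.shape)   # torch.Size([2, 6])  -> [batch_size, tgt_len]
print(dec_outputs.shape)  # torch.Size([2, 6])  -> [batch_size, tgt_len]
print(enc_inputs)         # tensor([[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]])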

Model Parameters

The variables below mean, in order:

  • d_model: word embedding dimension (= positional encoding dimension)
  • d_ff: the hidden dimension between the two linear layers of the Feed Forward network (512 -> 2048 -> 512)
  • d_k, d_v: the dimensions of K and V; d_q equals d_k, so it is not listed separately
  • n_layers: the number of EncoderLayers, i.e. the number of blocks
  • n_heads: the number of heads in Multi-Head Attention
# Transformer Parameters
d_model = 512   # Embedding Size (= positional encoding size)
d_ff = 2048     # FeedForward dimension (512 -> 2048 -> 512)
d_k = d_v = 64  # dimension of K(=Q) and V per head
n_layers = 6    # number of EncoderLayers (= number of blocks)
n_heads = 8     # number of heads in Multi-Head Attention
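
One relation worth noticing (my own sanity check, not part of the original code): each head works on a d_k-dimensional slice, and n_heads * d_k equals d_model here, which is why the Q/K/V projections later can be fused into single d_model -> d_k * n_heads linear layers.

# Not in the original code: sanity check on the hyper-parameters above
assert d_model == n_heads * d_k == n_heads * d_v  # 512 == 8 * 64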

Positional Encoding

The positional encoding module works as follows. The input sequence has already been word-embedded, so this module receives the word embeddings (produced with shape [batch_size, src_len, d_model] and transposed to [seq_len, batch_size, d_model] before the call, as explained below). The positional encodings themselves are fixed and are generated once when the module is initialized. The module computes pos_embedding + word_embedding and returns the sum with the same shape as its input, giving an input representation that carries both the word embedding and the positional information.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        :param x: [seq_len, batch_size, d_model]
        '''
        # self.pe[:x.size(0), :] -- [seq_len, 1, d_model]
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

The computation follows the formulas from the paper:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

  • div_term corresponds to the reciprocal of the denominator, 10000^(-2i / d_model); the code rewrites it with exp and log as exp(2i * (-ln(10000) / d_model)), which is equivalent (worth deriving by hand once).
  • i runs over 0, 1, 2, ..., d_model/2 - 1. In the second dimension of pe[max_len, d_model], every even index holds a sin value and every odd index holds a cos value, which together fill exactly d_model dimensions. In the code this means stepping from 0 to d_model with stride 2, i.e. [0, 2, 4, ..., d_model-2] (d_model/2 indices), computing a sin and a cos for each, and writing them into the even and odd slots of pe respectively.

Here I want to spell out exactly how the positional encoding pe is added to the embedded input x.

First, about pe, the positional embedding table. 1) When pe is created its shape is [max_len, d_model]: it encodes positions 0 to max_len-1, each as a d_model-dimensional vector. A real sentence is usually shorter than max_len, which is fine because the table is sliced later. 2) After the table is filled, a dimension is inserted so the shape becomes [max_len, 1, d_model]. 3) pe[:x.size(0), :] slices the table down to the actual sequence length, giving [seq_len, 1, d_model]. The input x here has shape [seq_len, batch_size, d_model] (the caller swaps dim 0 and dim 1 before passing it in).

In x + pe[:x.size(0), :], the size-1 second dimension of pe is broadcast to batch_size. batch_size is just the number of sentences, so in effect every word at every position of every sentence gets its positional encoding added.

[seq_len, batch_size, d_model] + [seq_len, 1, d_model] -> [seq_len, batch_size, d_model]. The module returns x, and the caller then applies x.transpose(0, 1) to obtain the word-embedded, position-encoded input of shape [batch_size, seq_len, d_model].
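
Here is a minimal, shape-only sketch of that broadcast, using random placeholder tensors (x_demo and pe_demo are made-up names; only d_model comes from the parameters above):

# Illustration only: the pe slice broadcasts over the batch dimension
seq_len, batch_size = 5, 2
x_demo = torch.randn(seq_len, batch_size, d_model)   # stand-in for the embedded input
pe_demo = torch.randn(5000, 1, d_model)              # same layout as the registered 'pe' buffer
out = x_demo + pe_demo[:x_demo.size(0), :]           # [5, 1, 512] broadcast against [5, 2, 512]
print(out.shape)                                     # torch.Size([5, 2, 512])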

Pad Mask

The purpose here is to mask the padding: the attention weight of every other token on the padding token 'P' should end up (almost) zero. The printout after the code block below shows the mask for our toy inputs.

def get_attn_pad_mask(seq_q, seq_k):  # mask the pad positions
    '''
    :param seq_q: [batch_size, seq_len]
    :param seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is the PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True means masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]
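
For the toy enc_inputs built earlier, only the last column (the 'P' position) is marked True. A quick check, assuming the data from the preprocessing step:

# enc_inputs = tensor([[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]])
mask = get_attn_pad_mask(enc_inputs, enc_inputs)
print(mask.shape)   # torch.Size([2, 5, 5])
print(mask[0, 0])   # tensor([False, False, False, False,  True]) -> only the pad column is masked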

Subsequence Mask

To make sure the prediction at the current position cannot use information from later positions, each token's attention to the tokens after it has to be masked. Take "我 爱 你" -> "I love you" as an example: the decoder first receives the start symbol and predicts "I", then receives "S I" and predicts "love", then "S I love" and predicts "you". Unlike an RNN there is no recurrence here; self-attention computes all attention weights in parallel, so instead we mask each token's attention to the tokens that come after it. That guarantees that translating the current position never looks at future information. The printout after the code block below shows what the resulting mask matrix looks like.

def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # upper triangular matrix (above the diagonal)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]
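
For tgt_len = 6 the mask is the upper triangle below: row i is the query position, and every position after it is marked 1. The decoder later merges it with the pad mask via torch.gt(pad_mask + subsequence_mask, 0). A quick printout, assuming dec_inputs from the preprocessing step:

sub_mask = get_attn_subsequence_mask(dec_inputs)
print(sub_mask[0])
# tensor([[0, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.uint8)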

ScaledDotProductAttention

The main steps: take the dot product of Q and K^T and scale it by sqrt(d_k) to get scores; apply the mask by setting the positions that should not be attended to to a very large negative number; apply softmax to turn the scores into attention weights attn; finally multiply attn with V to get the self-attention output. Note that the code keeps the head dimension throughout, so all heads are computed in parallel here, and their outputs are concatenated afterwards in MultiHeadAttention (the concat step).

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        :param Q: [batch_size, n_heads, len_q, d_k]
        :param K: [batch_size, n_heads, len_k, d_k]
        :param V: [batch_size, n_heads, len_v(=len_k), d_v]
        :param attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)  # fills elements with -1e9 where attn_mask is True
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn
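
If the shapes feel abstract, a quick run with random tensors makes them concrete. This is only a sketch with made-up sizes and no masked positions:

B, S = 2, 5
Q_demo = torch.randn(B, n_heads, S, d_k)
K_demo = torch.randn(B, n_heads, S, d_k)
V_demo = torch.randn(B, n_heads, S, d_v)
no_mask = torch.zeros(B, n_heads, S, S, dtype=torch.bool)      # nothing masked
ctx, attn = ScaledDotProductAttention()(Q_demo, K_demo, V_demo, no_mask)
print(ctx.shape, attn.shape)   # torch.Size([2, 8, 5, 64]) torch.Size([2, 8, 5, 5])
print(attn[0, 0].sum(dim=-1))  # every row of attention weights sums to 1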

MultiHeadAttention

self.W_Q, self.W_K, self.W_V are three linear projections that map the input (e.g. the embedded enc_inputs) to Q, K and V. Notice that they project straight to d_k * n_heads dimensions: unlike the one-projection-per-head picture you may have in mind from the paper, the implementation produces all n_heads = 8 head projections of the input in one shot.

After a series of reshapes, Q, K and V end up with shape [B, H, S, W]. How to read this: look at the last two dimensions first. S is the sequence length (the number of tokens) and W is the per-head dimension of Q/K/V, so [S, W] is one small matrix. The H dimension stacks n_heads such small matrices side by side, and B is the number of sentences, each of which has 8 projected versions.

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)  # project to all 8 heads' Q at once
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        :param input_Q: [batch_size, len_q, d_model]
        :param input_K: [batch_size, len_k, d_model]
        :param input_V: [batch_size, len_v(=len_k), d_model]
        :param attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask: [batch_size, n_heads, seq_len, seq_len]
        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).cuda()(output + residual), attn  # residual connection + LayerNorm
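
The split-into-heads step is nothing more than a view plus a transpose, and the "concat" at the end is just the reverse. A shape-only sketch with a random tensor (x_demo and w_q_demo are illustrative names, not part of the model):

x_demo = torch.randn(2, 5, d_model)                        # (B, S, D) = [2, 5, 512]
w_q_demo = nn.Linear(d_model, d_k * n_heads, bias=False)   # all 8 heads fused into one projection
q_all = w_q_demo(x_demo)                                   # [2, 5, 512]
q_heads = q_all.view(2, -1, n_heads, d_k).transpose(1, 2)  # (B, H, S, W) = [2, 8, 5, 64]
print(q_heads.shape)
q_concat = q_heads.transpose(1, 2).reshape(2, -1, n_heads * d_k)  # back to [2, 5, 512]: the concat
print(q_concat.shape)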

FeedForward Layer

Two linear layers with a ReLU in between, followed by the residual connection and LayerNorm; nothing special to explain.

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def forward(self, inputs):
        '''
        :param inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)  # [batch_size, seq_len, d_model]

Encoder Layer

Contains two sub-layers: Multi-Head Attention and the FeedForward network.

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        :param enc_inputs: [batch_size, src_len, d_model]
        :param enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

DecoderLayer

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

Encoder

Composed of two embedding layers (src_emb for word embeddings, pos_emb for positional encodings) and n_layers stacked EncoderLayers.

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        :param enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

Decoder

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda()  # [batch_size, tgt_len, d_model]
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]
        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

Transformer

The model consists of an Encoder, a Decoder and a projection layer. The Encoder is a stack of 6 EncoderLayers; every EncoderLayer maps [batch_size, src_len, d_model] to an output of the same shape, so it can be passed directly to the next EncoderLayer. After the Decoder, the projection layer reduces the dimension from d_model down to tgt_vocab_size, so the output can be turned into a probability distribution over the target vocabulary.

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()

    def forward(self, enc_inputs, dec_inputs):
        '''
        :param enc_inputs: [batch_size, src_len]
        :param dec_inputs: [batch_size, tgt_len]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
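
The logits are flattened to [batch_size * tgt_len, tgt_vocab_size] because nn.CrossEntropyLoss expects a 2-D score matrix and a 1-D target vector; dec_outputs.view(-1) in the training loop below supplies the matching targets, and ignore_index=0 skips the pad positions. A shape-only illustration with random logits (not the model's real output):

# Shapes only: how the flattened logits line up with the flattened targets
logits_demo = torch.randn(2 * tgt_len, tgt_vocab_size)  # [batch_size * tgt_len, tgt_vocab_size]
targets_demo = dec_outputs.view(-1)                     # [batch_size * tgt_len], values are vocab indices
loss_demo = nn.CrossEntropyLoss(ignore_index=0)(logits_demo, targets_demo)
print(logits_demo.shape, targets_demo.shape, loss_demo.item())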

Model Training

model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        dec_outputs: [batch_size, tgt_len]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Greedy_decoder

You can comment this part out at first; it is not training but inference: greedy decoding generates the target sentence token by token with the trained model (it is equivalent to beam search with K = 1). Worth a look once the training loop works.

def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 6
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)
    return dec_input

# Test
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

Complete Code

# -*- coding: utf-8 -*-
"""Transformer-Torch
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/15yTJSjZpYuIWzL9hSbyThHLer4iaJjBD
"""
'''
code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612, modify by wmathor
Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
            https://github.com/JayParks/transformer
'''
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
sentences = [
    # enc_input               dec_input            dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab)
tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)
src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length

# Transformer Parameters
d_model = 512   # Embedding Size
d_ff = 2048     # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6    # number of Encoder/Decoder layers
n_heads = 8     # number of heads in Multi-Head Attention

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]   # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]   # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]  # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]
        enc_inputs.extend(enc_input)
        dec_inputs.extend(dec_input)
        dec_outputs.extend(dec_output)
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet, self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # Upper triangular matrix
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)  # Fills elements of self tensor with value where mask is True.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)  # [batch_size, n_heads, len_q, d_v]
        return context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).cuda()(output + residual), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)  # [batch_size, seq_len, d_model]

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)  # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda()  # [batch_size, tgt_len, d_model]
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda()  # [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # [batch_size, tgt_len, src_len]
        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # tensor to store decoder outputs
        # outputs = torch.zeros(batch_size, tgt_len, tgt_vocab_size).to(self.device)
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        dec_outputs: [batch_size, tgt_len]
        '''
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs, dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 6
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        dec_input = torch.cat([dec_input.detach(), torch.tensor([[next_symbol]], dtype=enc_input.dtype).cuda()], -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tgt_vocab["."]:
            terminal = True
        print(next_word)
    return dec_input

# Test
enc_inputs, _, _ = next(iter(loader))
enc_inputs = enc_inputs.cuda()
for i in range(len(enc_inputs)):
    greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[i].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict.squeeze()])

Summary

This is the first version. If any of the code is wrong, feel free to point it out in the comments and I will keep fixing it; I am also new to the Transformer, so corrections on anything I misunderstood are welcome. I believe working through the code line by line yourself will help your understanding enormously. Thanks, and don't forget to like!
