赞
踩
本篇文章适合已经仔细阅读过Transformer论文和源码的同学,使用的代码为Pytorch版,且只是记录本人在学习过程中的个人理解,如果出现错误的地方,请在评论区友好交流,感谢指正。
- def forward(self, src_seq, trg_seq):
-
- src_mask = get_pad_mask(src_seq, self.src_pad_idx) #只对有效长度进行attention计算,pad的0需要mask
- trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) #不仅mask padding部分,还mask上三角
-
- enc_output, *_ = self.encoder(src_seq, src_mask)
- dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask) #这个trg_seq是指目标语言
- seq_logit = self.trg_word_prj(dec_output)
- if self.scale_prj:
- seq_logit *= self.d_model ** -0.5
-
- return seq_logit.view(-1, seq_logit.size(2))
'运行
我们知道src_seq是句子的完整序列,而trg_seq是目标序列,即解码器一步一个字解码出来的。src_mask是消除掉为了补齐长度而用来padding的元素对注意力的影响得到的mask(对应函数get_pad_mask),而trg_mask除了此种掩码之外,还用到了一个消除掉暂时还未解码出的字的影响的掩码机制(对应函数get_subsequent_mask)。下面我们看看encoder和decoder中,这些掩码机制到底是怎么工作的。
假设src_seq的尺寸为batch_size=2,seq_len=3,0为padding的即需要进行掩码的元素。我们以encoder里注意力机制中的掩码机制为例,举一个运行原理相同,但简单易理解的例子:
- import torch
-
- def get_pad_mask(seq, pad_idx):
- return (seq != pad_idx).unsqueeze(-2)
-
- if __name__ == "__main__":
- seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
- embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # 对每一个字进行编码,每个字维度为10
- query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
- att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
- """
- att:
- tensor([[[6.3899, 0.5517, 0.0000],
- [0.5517, 6.4545, 0.0000],
- [0.0000, 0.0000, 0.0000]],
- [[14.2199, -1.2504, -3.7615],
- [-1.2504, 8.2810, 0.3213],
- [-3.7615, 0.3213, 12.7485]]])
- """
- mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
- """
- mask:
- tensor([[[True, True, False]],
- [[True, True, True]]])
- """
- masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
- """
- masked_att:
- tensor([[[6.3899e+00, 5.5172e-01, -1.0000e+09],
- [5.5172e-01, 6.4545e+00, -1.0000e+09],
- [0.0000e+00, 0.0000e+00, -1.0000e+09]],
- [[1.4220e+01, -1.2504e+00, -3.7615e+00],
- [-1.2504e+00, 8.2810e+00, 3.2127e-01],
- [-3.7615e+00, 3.2127e-01, 1.2749e+01]]])
- """
注意到
masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
这行代码用到了广播机制,即将mask的尺寸由[2,1,3]广播为[2,3,3],然后与注意力矩阵att对应,广播之后的mask如下所示:
- """
- mask:
- tensor([[[True, True, False],
- [True, True, False],
- [True, True, False]],
- [[True, True, True],
- [True, True, True],
- [True, True, True]]])
- """
'运行
- class DecoderLayer(nn.Module):
- ''' Compose with three layers '''
-
- def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
- super(DecoderLayer, self).__init__()
- self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
- self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
- self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
-
- def forward(
- self, dec_input, enc_output,
- slf_attn_mask=None, dec_enc_attn_mask=None):
- dec_output, dec_slf_attn = self.slf_attn(
- dec_input, dec_input, dec_input, mask=slf_attn_mask)
- dec_output, dec_enc_attn = self.enc_attn(
- dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
- dec_output = self.pos_ffn(dec_output)
- return dec_output, dec_slf_attn, dec_enc_attn
上面是代码中对应decoder的部分(Layers.py),self.slf_attn和self.enc_attn分别对应第一个和第二个sublayer。 阅读源代码后我们发现,首先,传入forward的参数中,slf_attn_mask对应传入的是trg_mask,dec_enc_attn_mask对应传入的是src_mask。src_mask是通过get_pad_mask函数得到的掩码,而trg_mask是通过get_pad_mask和dec_enc_attn_mask两个函数得到的掩码。我们模拟第一个sublayer,举一个例子:
- import torch
-
- def get_pad_mask(seq, pad_idx):
- return (seq != pad_idx).unsqueeze(-2)
-
- def get_subsequent_mask(seq):
- batch_size, seq_len = seq.size()
- mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
- mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
- return mask
-
- if __name__ == "__main__":
- seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
- embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # 对每一个字进行编码,每个字维度为10
- query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
- att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
- """
- att:
- tensor([[[6.3899, 0.5517, 0.0000],
- [0.5517, 6.4545, 0.0000],
- [0.0000, 0.0000, 0.0000]],
- [[14.2199, -1.2504, -3.7615],
- [-1.2504, 8.2810, 0.3213],
- [-3.7615, 0.3213, 12.7485]]])
- """
- p_mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
- """
- p_mask:
- tensor([[[True, True, False]],
- [[True, True, True]]])
- """
- s_mask = get_subsequent_mask(seq) # torch.Size([2, 3, 3])
- """
- s_mask:
- tensor([[[1, 0, 0],
- [1, 1, 0],
- [1, 1, 1]],
- [[1, 0, 0],
- [1, 1, 0],
- [1, 1, 1]]])
- """
- mask = p_mask & s_mask # torch.Size([2, 3, 3])
- """
- tensor([[[1, 0, 0],
- [1, 1, 0],
- [1, 1, 0]],
- [[1, 0, 0],
- [1, 1, 0],
- [1, 1, 1]]])
- """
- masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
- """
- masked_att:
- tensor([[[ 7.0830e+00, -1.0000e+09, -1.0000e+09],
- [-2.7239e+00, 1.5791e+01, -1.0000e+09],
- [ 0.0000e+00, 0.0000e+00, -1.0000e+09]],
-
- [[ 1.4356e+01, -1.0000e+09, -1.0000e+09],
- [ 4.3740e+00, 9.3937e+00, -1.0000e+09],
- [ 5.1368e+00, 1.1749e+00, 5.1988e+00]]])
- """
与encoder中的mask机制相比,这里只是多了一个上三角的掩码,其实这个例子为了更严谨,序列长度最好设的比encoder中短,这里偷了一个懒。
第二个sublayer中用的就是和encoder中一样的掩码了,没有用到上三角掩码,注意这里用的Q来自decoder,K和V来自encoder,对于这里的注意力机制,举一个例子:
- import torch
-
- def get_pad_mask(seq, pad_idx):
- return (seq != pad_idx).unsqueeze(-2)
-
- if __name__ == "__main__":
- seq_k = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3,padding_idx=0 torch.Size([2, 3])
- seq_q = torch.LongTensor([[4, 5], [6, 7]]) # batch_size=2, seq_len=2,padding_idx=0 torch.Size([2, 2])
- embedding_k = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0)
- embedding_q = torch.nn.Embedding(num_embeddings=8, embedding_dim=10, padding_idx=0)
- query = embedding_q(seq_q) # torch.Size([2, 3, 10])
- key = embedding_k(seq_k) # torch.Size([2, 3, 10])
- att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 2, 3])
- print(att)
- """
- att:
- tensor([[[-2.6561, -3.2418, 0.0000],
- [ 4.2412, -2.5950, 0.0000]],
-
- [[ 3.1960, 9.1766, 0.6027],
- [-4.0462, -1.4987, -0.4528]]])
- """
- mask = get_pad_mask(seq_k, 0) # torch.Size([2, 1, 3])
- """
- tensor([[[ True, True, False]],
- [[ True, True, True]]])
- """
- att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 2, 3])
- """
- tensor([[[-2.3804e-04, -1.8567e-01, -1.0000e+09],
- [ 4.5441e-01, 1.8053e-01, -1.0000e+09]],
- [[-1.8674e+00, 2.6307e+00, 2.6570e+00],
- [-1.5631e+00, 2.2473e-02, 3.7925e+00]]])
- """
原理与之前类似,只不过Q的序列长度会与K不同。
参考链接:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。