当前位置:   article > 正文

Transformer中mask机制理解_transformer中的src,mask以及poss(

transformer中的src,mask以及poss(

提示

本篇文章适合已经仔细阅读过Transformer论文和源码的同学,使用的代码为Pytorch版,且只是记录本人在学习过程中的个人理解,如果出现错误的地方,请在评论区友好交流,感谢指正。

前置内容

  1. def forward(self, src_seq, trg_seq):
  2. src_mask = get_pad_mask(src_seq, self.src_pad_idx) #只对有效长度进行attention计算,pad的0需要mask
  3. trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) #不仅mask padding部分,还mask上三角
  4. enc_output, *_ = self.encoder(src_seq, src_mask)
  5. dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask) #这个trg_seq是指目标语言
  6. seq_logit = self.trg_word_prj(dec_output)
  7. if self.scale_prj:
  8. seq_logit *= self.d_model ** -0.5
  9. return seq_logit.view(-1, seq_logit.size(2))
'
运行

我们知道src_seq是句子的完整序列,而trg_seq是目标序列,即解码器一步一个字解码出来的。src_mask是消除掉为了补齐长度而用来padding的元素对注意力的影响得到的mask(对应函数get_pad_mask),而trg_mask除了此种掩码之外,还用到了一个消除掉暂时还未解码出的字的影响的掩码机制(对应函数get_subsequent_mask)。下面我们看看encoder和decoder中,这些掩码机制到底是怎么工作的。

encoder中多头注意力的掩码机制

假设src_seq的尺寸为batch_size=2,seq_len=3,0为padding的即需要进行掩码的元素。我们以encoder里注意力机制中的掩码机制为例,举一个运行原理相同,但简单易理解的例子:

  1. import torch
  2. def get_pad_mask(seq, pad_idx):
  3. return (seq != pad_idx).unsqueeze(-2)
  4. if __name__ == "__main__":
  5. seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
  6. embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # 对每一个字进行编码,每个字维度为10
  7. query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
  8. att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
  9. """
  10. att:
  11. tensor([[[6.3899, 0.5517, 0.0000],
  12. [0.5517, 6.4545, 0.0000],
  13. [0.0000, 0.0000, 0.0000]],
  14. [[14.2199, -1.2504, -3.7615],
  15. [-1.2504, 8.2810, 0.3213],
  16. [-3.7615, 0.3213, 12.7485]]])
  17. """
  18. mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
  19. """
  20. mask:
  21. tensor([[[True, True, False]],
  22. [[True, True, True]]])
  23. """
  24. masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
  25. """
  26. masked_att:
  27. tensor([[[6.3899e+00, 5.5172e-01, -1.0000e+09],
  28. [5.5172e-01, 6.4545e+00, -1.0000e+09],
  29. [0.0000e+00, 0.0000e+00, -1.0000e+09]],
  30. [[1.4220e+01, -1.2504e+00, -3.7615e+00],
  31. [-1.2504e+00, 8.2810e+00, 3.2127e-01],
  32. [-3.7615e+00, 3.2127e-01, 1.2749e+01]]])
  33. """

注意到

masked_att = att.masked_fill(mask==0, -1e9)     # torch.Size([2, 3, 3])

这行代码用到了广播机制,即将mask的尺寸由[2,1,3]广播为[2,3,3],然后与注意力矩阵att对应,广播之后的mask如下所示:

  1. """
  2. mask:
  3. tensor([[[True, True, False],
  4. [True, True, False],
  5. [True, True, False]],
  6. [[True, True, True],
  7. [True, True, True],
  8. [True, True, True]]])
  9. """
'
运行

decoder中第一个多头注意力的掩码机制

  1. class DecoderLayer(nn.Module):
  2. ''' Compose with three layers '''
  3. def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
  4. super(DecoderLayer, self).__init__()
  5. self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  6. self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  7. self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  8. def forward(
  9. self, dec_input, enc_output,
  10. slf_attn_mask=None, dec_enc_attn_mask=None):
  11. dec_output, dec_slf_attn = self.slf_attn(
  12. dec_input, dec_input, dec_input, mask=slf_attn_mask)
  13. dec_output, dec_enc_attn = self.enc_attn(
  14. dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
  15. dec_output = self.pos_ffn(dec_output)
  16. return dec_output, dec_slf_attn, dec_enc_attn

 上面是代码中对应decoder的部分(Layers.py),self.slf_attn和self.enc_attn分别对应第一个和第二个sublayer。 阅读源代码后我们发现,首先,传入forward的参数中,slf_attn_mask对应传入的是trg_mask,dec_enc_attn_mask对应传入的是src_mask。src_mask是通过get_pad_mask函数得到的掩码,而trg_mask是通过get_pad_mask和dec_enc_attn_mask两个函数得到的掩码。我们模拟第一个sublayer,举一个例子:

  1. import torch
  2. def get_pad_mask(seq, pad_idx):
  3. return (seq != pad_idx).unsqueeze(-2)
  4. def get_subsequent_mask(seq):
  5. batch_size, seq_len = seq.size()
  6. mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
  7. mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
  8. return mask
  9. if __name__ == "__main__":
  10. seq = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3, padding_idx=0, torch.Size([2, 3])
  11. embedding = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0) # 对每一个字进行编码,每个字维度为10
  12. query, key = embedding(seq), embedding(seq) # torch.Size([2, 3, 10]), torch.Size([2, 3, 10])
  13. att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 3, 3])
  14. """
  15. att:
  16. tensor([[[6.3899, 0.5517, 0.0000],
  17. [0.5517, 6.4545, 0.0000],
  18. [0.0000, 0.0000, 0.0000]],
  19. [[14.2199, -1.2504, -3.7615],
  20. [-1.2504, 8.2810, 0.3213],
  21. [-3.7615, 0.3213, 12.7485]]])
  22. """
  23. p_mask = get_pad_mask(seq, 0) # torch.Size([2, 1, 3])
  24. """
  25. p_mask:
  26. tensor([[[True, True, False]],
  27. [[True, True, True]]])
  28. """
  29. s_mask = get_subsequent_mask(seq) # torch.Size([2, 3, 3])
  30. """
  31. s_mask:
  32. tensor([[[1, 0, 0],
  33. [1, 1, 0],
  34. [1, 1, 1]],
  35. [[1, 0, 0],
  36. [1, 1, 0],
  37. [1, 1, 1]]])
  38. """
  39. mask = p_mask & s_mask # torch.Size([2, 3, 3])
  40. """
  41. tensor([[[1, 0, 0],
  42. [1, 1, 0],
  43. [1, 1, 0]],
  44. [[1, 0, 0],
  45. [1, 1, 0],
  46. [1, 1, 1]]])
  47. """
  48. masked_att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 3, 3])
  49. """
  50. masked_att:
  51. tensor([[[ 7.0830e+00, -1.0000e+09, -1.0000e+09],
  52. [-2.7239e+00, 1.5791e+01, -1.0000e+09],
  53. [ 0.0000e+00, 0.0000e+00, -1.0000e+09]],
  54. [[ 1.4356e+01, -1.0000e+09, -1.0000e+09],
  55. [ 4.3740e+00, 9.3937e+00, -1.0000e+09],
  56. [ 5.1368e+00, 1.1749e+00, 5.1988e+00]]])
  57. """

与encoder中的mask机制相比,这里只是多了一个上三角的掩码,其实这个例子为了更严谨,序列长度最好设的比encoder中短,这里偷了一个懒。

decoder中第二个多头注意力的掩码机制

第二个sublayer中用的就是和encoder中一样的掩码了,没有用到上三角掩码,注意这里用的Q来自decoder,K和V来自encoder,对于这里的注意力机制,举一个例子:

  1. import torch
  2. def get_pad_mask(seq, pad_idx):
  3. return (seq != pad_idx).unsqueeze(-2)
  4. if __name__ == "__main__":
  5. seq_k = torch.LongTensor([[1, 2, 0],[3, 4, 5]]) # batch_size=2, seq_len=3,padding_idx=0 torch.Size([2, 3])
  6. seq_q = torch.LongTensor([[4, 5], [6, 7]]) # batch_size=2, seq_len=2,padding_idx=0 torch.Size([2, 2])
  7. embedding_k = torch.nn.Embedding(num_embeddings=6, embedding_dim=10, padding_idx=0)
  8. embedding_q = torch.nn.Embedding(num_embeddings=8, embedding_dim=10, padding_idx=0)
  9. query = embedding_q(seq_q) # torch.Size([2, 3, 10])
  10. key = embedding_k(seq_k) # torch.Size([2, 3, 10])
  11. att = torch.matmul(query, key.transpose(-2, -1)) # torch.Size([2, 2, 3])
  12. print(att)
  13. """
  14. att:
  15. tensor([[[-2.6561, -3.2418, 0.0000],
  16. [ 4.2412, -2.5950, 0.0000]],
  17. [[ 3.1960, 9.1766, 0.6027],
  18. [-4.0462, -1.4987, -0.4528]]])
  19. """
  20. mask = get_pad_mask(seq_k, 0) # torch.Size([2, 1, 3])
  21. """
  22. tensor([[[ True, True, False]],
  23. [[ True, True, True]]])
  24. """
  25. att = att.masked_fill(mask==0, -1e9) # torch.Size([2, 2, 3])
  26. """
  27. tensor([[[-2.3804e-04, -1.8567e-01, -1.0000e+09],
  28. [ 4.5441e-01, 1.8053e-01, -1.0000e+09]],
  29. [[-1.8674e+00, 2.6307e+00, 2.6570e+00],
  30. [-1.5631e+00, 2.2473e-02, 3.7925e+00]]])
  31. """

 原理与之前类似,只不过Q的序列长度会与K不同。

mask机制的代码思想

参考链接:

 NLP 中的Mask全解_mask在自然语言处理代表什么-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/908626
推荐阅读
相关标签
  

闽ICP备14008679号