当前位置:   article > 正文

Transform 相关知识(Mask)_sequence mask

sequence mask

1. Mask

 mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。

  • padding mask:处理非定长序列,区分padding和非padding部分,如在RNN等模型和Attention机制中的应用等
  • equence mask:防止标签泄露,如:Transformer decoder中的mask矩阵,BERT中的[Mask]位,XLNet中的mask矩阵等

1.1 Padding Mask

因为每个批次输入序列长度是不一样,需要对输入序列进行对齐。给较短的序列后面填充 0,对于输入太长的序列,截取左边的内容,把多余的直接舍弃。这些填充的位置,没什么意义,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 softmax,这些位置的概率就会接近0!

而我们的 padding mask 实际上是一个张量,每个值都是一个Boolean,值为 false 的地方就是我们要进行处理的地方。

1.2 Sequence mask


sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。

具体做法:产生一个上三角矩阵,上三角的值全为0。把这个矩阵作用在每一个序列上。

对于 decoder 的 self-attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。
其他情况,attn_mask 一律等于 padding mask。

代码:

  1. import torch
  2. def padding_mask(seq, pad_idx):
  3. return (seq != pad_idx).unsqueeze(-2) # [B, 1, L]
  4. def sequence_mask(seq):
  5. batch_size, seq_len = seq.size()
  6. mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
  7. mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
  8. return mask
  9. def test():
  10. # 以最简化的形式测试Transformer的两种mask
  11. seq = torch.LongTensor([[1,2,0]]) # batch_size=1, seq_len=3,padding_idx=0
  12. embedding = torch.nn.Embedding(num_embeddings=3, embedding_dim=10, padding_idx=0)
  13. query, key = embedding(seq), embedding(seq)
  14. scores = torch.matmul(query, key.transpose(-2, -1))
  15. mask_p = padding_mask(seq, 0)
  16. mask_s = sequence_mask(seq)
  17. mask_decoder = mask_p & mask_s # 结合 padding mask 和 sequence mask
  18. scores_encoder = scores.masked_fill(mask_p==0, -1e9) # 对于scores,在mask==0的位置填充
  19. scores_decoder = scores.masked_fill(mask_decoder==0, -1e9)
  20. test()

参考:Transform详解_霜叶的博客-CSDN博客_transform

           NLP 中的Mask全解 - 知乎

    

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/307027
推荐阅读
相关标签
  

闽ICP备14008679号