赞
踩
mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。
因为每个批次输入序列长度是不一样,需要对输入序列进行对齐。给较短的序列后面填充 0,对于输入太长的序列,截取左边的内容,把多余的直接舍弃。这些填充的位置,没什么意义,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。
具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 softmax,这些位置的概率就会接近0!
而我们的 padding mask 实际上是一个张量,每个值都是一个Boolean,值为 false 的地方就是我们要进行处理的地方。
sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。
具体做法:产生一个上三角矩阵,上三角的值全为0。把这个矩阵作用在每一个序列上。
对于 decoder 的 self-attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。
其他情况,attn_mask 一律等于 padding mask。
代码:
- import torch
-
- def padding_mask(seq, pad_idx):
- return (seq != pad_idx).unsqueeze(-2) # [B, 1, L]
-
- def sequence_mask(seq):
- batch_size, seq_len = seq.size()
- mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
- mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
- return mask
-
- def test():
- # 以最简化的形式测试Transformer的两种mask
- seq = torch.LongTensor([[1,2,0]]) # batch_size=1, seq_len=3,padding_idx=0
- embedding = torch.nn.Embedding(num_embeddings=3, embedding_dim=10, padding_idx=0)
- query, key = embedding(seq), embedding(seq)
- scores = torch.matmul(query, key.transpose(-2, -1))
-
- mask_p = padding_mask(seq, 0)
- mask_s = sequence_mask(seq)
- mask_decoder = mask_p & mask_s # 结合 padding mask 和 sequence mask
-
- scores_encoder = scores.masked_fill(mask_p==0, -1e9) # 对于scores,在mask==0的位置填充
- scores_decoder = scores.masked_fill(mask_decoder==0, -1e9)
-
- test()
参考:Transform详解_霜叶的博客-CSDN博客_transform
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。