赞
踩
之前我曾经在一篇博客中有介绍关于Transformer模型的Encoder部分的mask,在这篇文章中,我打算将Decoder部分的mask机制也补充完整。OK,那么我们进入正题。
首先我们先来看一下Decoder的结构,下面这张图取自google在2017年发的著名文章:Attention Is All You Need,也就是最初提出Transformer结构的文章。
为了专注于Decoder的部分,关于原图中Encoder的部分我并没有截取。如果我们仔细观察Decoder的部分,我们会发现有两个多头注意力模块(Multi-Head Attention),但是第一个明确标注了'Masked'。关于这部分Mask,我们首先来看一下原文的说明。
在解读这段话之前,我们首先来明确一个问题,那就是本质上Transformer解决的也是一个sequence to sequence的问题,通过encoder来对输入序列进行编码(embedding representation)然后通过decoder进行解码,生成一个新的序列。所以本质上decoder就是一个生成器。但是因为生成的对象是一个序列,因此在通过decoder模型进行生成时,decoder是一个位置一个位置来逐步生成的,每一个位置i的预测结果的生成都仅仅依赖于前i个位置构成的sub sequence。
明确了这一点之后,我们再回看原文的解释就比较清楚了,为了保证和在预测阶段的逻辑的一致性,在训练阶段,decoder每一个位置的模型输出都仅仅依赖于这个位置之前的embedding进行加权融合的结果,而序列后面部分的数据在训练时需要保证不被模型看到(Masked)。
但是在实际训练中,我们一般不会选择一个位置一个位置来训练,而是拿整体的目标序列。这里涉及到一个非常重要的点,就是上图中的shifted,比如说,我们的目标序列是[“hi”,“i”,“am”,“frank”,“hu”],那么我们的decoder的input就会是[“SOS”,“hi”,“i’,”am“,”frank“,”hu“],而我们的target则是['hi','i','am','frank','hu',''EOS],其中"SOS"和“EOS”分别作为开头和结尾的标示符。然后根据input与target的对应关系,我们的训练逻辑就是,1:给出‘SOS’,希望预测得到‘hi’。 2.给出[‘SOS’,‘hi’],希望预测得到[‘hi ’,'i']。3.给出['SOS','hi','i'],希望预测得到['hi','i','am']... 也就是保证了在预测每个位置的目标时,都只使用目标序列中在该位置之前的所有信息,来进行Attention融合。
下图将上面解释的mask应用在attention的结果上,每行(不包括SOS)代表每个target word在被预测时能够被允许看到的序列部分(列)。
下面用代码来实现一下这部分mask.
- import torch
- def seq_mask(seq):
- batch_size,seq_len=seq.size()
- sub_seq_mask = (1 - torch.triu(torch.ones((1,seq_len,seq_len)),diagonal=1)).bool()
- return sub_seq_mask
但是除了这个意义上的mask,在实做上,我们还需要对于padding的mask,这部分和encoder是一样的,本质上就是因为输入序列部分是padding填充的,没有时间意义,在attention时也不应该考虑,因此同样需要mask这部分。
同样贴一下代码
- import torch
- def get_pad_mask(seq,pad_idx):
- return (seq!=pad_idx).unsqueeze(-2)
OK,以上就是关于Decoder部分的mask的解读,我自己认为说的还是不够清楚,可能是自己的理解还不到家,如果朋友们有任何问题,欢迎在评论区留言,一起交流进步。最后谢谢大家的阅读。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。