当前位置:   article > 正文

nlp中的transformer中的mask

nlp中的transformer中的mask

由于在实现多头注意力时需要考虑到各种情况下的掩码,因此在这里需要先对这部分内容进行介绍。在Transformer中,主要有两个地方会用到掩码这一机制。第1个地方就是在上一篇文章用介绍到的Attention Mask,用于在训练过程中解码的时候掩盖掉当前时刻之后的信息;第2个地方便是对一个batch中不同长度的序列在Padding到相同长度后,对Padding部分的信息进行掩盖。下面分别就这两种情况进行介绍。

1.Attention Mask

实现:generate_square_subsequent_mask

  1. def _generate_square_subsequent_mask(self, sz):
  2. mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
  3. mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  4. return mask

2.Padding Mask

实现:

用法:

https://blog.csdn.net/vivi_cin/article/details/135390462

参考:

nn.TransformerEncoderLayer中的src_mask,src_key_padding_mask解析_src_mask和src_key_padding_mask-CSDN博客

(32 封私信 / 4 条消息) transformer中: self-attention部分是否需要进行mask? - 知乎 (zhihu.com) 几个很好的回答:

Q:transformer中attention_mask一定需要嘛?

A:Transformer结构包括编码器和解码器,在编码过程中目的就是为了让模型看到当前位置前后的信息,所以不需要attention mask。但是在解码过程中为了模拟在真实的inference场景中,当前位置看不到下一位置,且同时需要上一位置的信息,所以在训练的时候加了attention mask。

所以,如果你的任务在实际的inference中也符合这样的特点,那么你在训练的时候也是需要attention,相反则不需要。

参考:(32 封私信 / 4 条消息) transformer中attention_mask一定需要嘛? - 知乎 (zhihu.com)

还有一个写的很好的博主:

 nn.TransformerEncoderLayer中的src_mask,src_key_padding_mask解析_src_mask和src_key_padding_mask-CSDN博客

参考的github上关于pad mask 实现 :

https://github.com/HIT-SCIR/plm-nlp-code/blob/64564b643a09cb85163ccca1f8c41fc94f5fc9ec/chp4/utils.py

关键代码:

  1. def length_to_mask(lengths):
  2. max_len = torch.max(lengths)
  3. mask = torch.arange(max_len, device=lengths.device).expand(lengths.shape[0], max_len) < lengths.unsqueeze(1)
  4. return mask

 model:

  1. class Transformer(nn.Module):
  2. def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
  3. dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
  4. super(Transformer, self).__init__()
  5. # 词嵌入层
  6. self.embedding_dim = embedding_dim
  7. self.embeddings = nn.Embedding(vocab_size, embedding_dim)
  8. self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
  9. # 编码层:使用Transformer
  10. encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
  11. self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
  12. # 输出层
  13. self.output = nn.Linear(hidden_dim, num_class)
  14. def forward(self, inputs, lengths):
  15. inputs = torch.transpose(inputs, 0, 1)
  16. hidden_states = self.embeddings(inputs)
  17. hidden_states = self.position_embedding(hidden_states)
  18. attention_mask = length_to_mask(lengths) == False
  19. hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
  20. logits = self.output(hidden_states)
  21. log_probs = F.log_softmax(logits, dim=-1)
  22. return log_probs

模型完整代码:

  1. class Transformer(nn.Module):
  2. def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
  3. dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
  4. super(Transformer, self).__init__()
  5. # 词嵌入层
  6. self.embedding_dim = embedding_dim
  7. self.embeddings = nn.Embedding(vocab_size, embedding_dim)
  8. self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
  9. # 编码层:使用Transformer
  10. encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
  11. self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
  12. # 输出层
  13. self.output = nn.Linear(hidden_dim, num_class)
  14. def forward(self, inputs, lengths):
  15. inputs = torch.transpose(inputs, 0, 1)
  16. hidden_states = self.embeddings(inputs)
  17. hidden_states = self.position_embedding(hidden_states)
  18. attention_mask = length_to_mask(lengths) == False
  19. hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
  20. logits = self.output(hidden_states)
  21. log_probs = F.log_softmax(logits, dim=-1)
  22. return log_probs

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/564384
推荐阅读
相关标签
  

闽ICP备14008679号