当前位置:   article > 正文

pytorch笔记:nn.MultiheadAttention

nn.multiheadattention

1 函数介绍

  1. torch.nn.MultiheadAttention(
  2. embed_dim,
  3. num_heads,
  4. dropout=0.0,
  5. bias=True,
  6. add_bias_kv=False,
  7. add_zero_attn=False,
  8. kdim=None,
  9. vdim=None,
  10. batch_first=False,
  11. device=None,
  12. dtype=None)

2 参数介绍

embed_dim模型的维度
num_heads

attention的头数

(embed_dim会平均分配给每个头,也即每个头的维度是embed_dim//num_heads)

dropoutattn_output_weights的dropout概率
biasinput和output的投影函数,是否有bias
kdim

k的维度,默认embed_dim

vdimv的维度,默认embed_dim
batch_firstTrue——输入和输出的维度是(batch_num,seq_len,feature_dim)
False——输入和输出的维度是(batch_num,seq_len,feature_dim)

3 forward函数

  1. forward(
  2. query,
  3. key,
  4. value,
  5. key_padding_mask=None,
  6. need_weights=True,
  7. attn_mask=None,
  8. average_attn_weights=True)

4 forward函数参数介绍

query
  • 对于没有batch的输入,维度是(length,embed_dim)
  • 对于有batch的输入,维度是(batch_num,len,embed_dim)或者(len,batch_num,embed_dim)【取决于batch_first】
key
  • 对于没有batch的输入,维度是(S_length,kdim)
  • 对于有batch的输入,维度是(batch_num,len,kdim)或者(len,batch_num,kdim)【取决于batch_first】
value
  • 对于没有batch的输入,维度是(S_length,vdim)
  • 对于有batch的输入,维度是(batch_num,len,vdim)或者(len,batch_num,vdim)【取决于batch_first】
key_padding_mask 

如果设置,那么

  • 对于没有batch的输入,这需要一个S_length大小的mask向量
  • 对于有batch的输入,这需要一个(length,S_length)大小的mask矩阵

True表示对应的key value在计算attention的时候,需要被忽略

need_weights如果设置,那么返回值会多一个attn_output_weight
attn_maskTrue表示对应的attention value 不应该存在
average_attn_weights 

如果设置,那么返回的是各个头的平均attention weight

否则,就是把所有的head分别输出

5 forward输出

attn_output
  • 对于没有batch的输入,维度为(length,embed_dim)
  • 对于有batch的输入,维度为(length,batch_size,embed_dim)或(batch_size,length,embed_dim)
attn_output_weight
  • 对于没有batch的输入
    • 如果average_attn_weights为True,那么就是(length,S_length);否则是(num_heads,length,S_length)

6 举例

  1. import torch
  2. import torch.nn as nn
  3. lst=torch.Tensor([[1,2,3,4],
  4. [2,3,4,5],
  5. [7,8,9,10]])
  6. lst=lst.unsqueeze(1)
  7. lst.shape
  8. #torch.Size([3, 1, 4])
  9. multi_atten=nn.MultiheadAttention(embed_dim=4,
  10. num_heads=2)
  11. multi_atten(lst,lst,lst)
  12. '''
  13. (tensor([[[ 1.9639, -3.7282, 2.1215, 0.6630]],
  14. [[ 2.2423, -4.2444, 2.2466, 1.0711]],
  15. [[ 2.3823, -4.5058, 2.3015, 1.2964]]], grad_fn=<AddBackward0>),
  16. tensor([[[9.0335e-02, 1.2198e-01, 7.8769e-01],
  17. [2.6198e-02, 4.4854e-02, 9.2895e-01],
  18. [1.6031e-05, 9.4658e-05, 9.9989e-01]]], grad_fn=<DivBackward0>))
  19. '''

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/346710
推荐阅读
相关标签
  

闽ICP备14008679号