当前位置:   article > 正文

Multi Query Attention & Group Query Attention_multi-query attention

multi-query attention

Multi Query Attention(MQA)在2019年就被提出来了,用于推理加速,但在当时并没有受到很多关注,毕竟一张2080就能跑Bert-base了。随着LLM的大火,MQA所带来的收益得以放大。

思路

Multi Query Attention(MQA)跟Multi Head Attention(MHA)只有一词之差,但其思路非常简单,几乎跟MHA一致:

model

MHA的Query、Key、Value分拆成8个头,每个头进行self-attention运算,而MQA是Query分成8个头,每个头共享一组Key和Value

MHA: Q, K, V = (512, 768), # seq_len, hidden_dim
			拆成8个头:
			Q : (8, 512, 96) 
			k, v: (8, 512, 96)
MQA: 
 Q -> (512, 768) 
 K -> (512, 96)
 v -> (512, 96)
把Q拆成8个头:
Q: (8, 512, 96)
K, V:(512, 96)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

代码实现

  • MHA
...
self.Wqkv = nn.Linear( 
            d_model,
            d_model * 3,
            device=device,
        )
...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

d_model * 3 拆成3个768维

  • MQA
...
self.Wqkv = nn.Linear( 
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )
...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

d_model + 2 * self.head_dim 拆成1个768维 + 2个96维

可以看到参数数量大幅减少。

实验结果

实验指标略微降低,但推理加速非常明显。

result

Group Query Attention

Q拆分成8个头,K和V分别拆成4个头,然后对应进行attention运算。


参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/439854
推荐阅读
相关标签
  

闽ICP备14008679号