当前位置:   article > 正文

Grouped Query Attention论文阅读_group query attention

group query attention

论文:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

1. 背景介绍

Google在2023年发表的一篇关于Transformer Attention的论文,整体论文写的清晰易读,思想简单但很好用。论文名字简写是GQA,但实际分别代表了两种缩写:

  1. Generalized Multi Query Attention
  2. Grouped Query Attention

2. 详细介绍

2.1 通用Multi-Query Attention

在之前的Multi-Query Attention M Q A MQA MQA】方法中只会保留一个单独的key-value头,这样虽然可以提升推理的速度,但是会带来精度的损失,这篇论文的第一个思路是基于多个 M Q A MQA MQA的checkpoint进行finetuning,来得到了一个质量更高的 M Q A MQA MQA模型。这个过程也被称为Uptraining

具体分为两步:

  1. 对多个 M Q A MQA MQA的checkpoint文件进行融合,融合的方法是通过对key和value的head头进行mean pooling操作,如下图。
  2. 对融合后的模型使用少量数据进行finetune训练,重训后的模型大小跟之前一样,但是效果会更好

在这里插入图片描述

2.2 Grouped-query attention

如下图所示,在一般的attention中是Multi-head多头结构,每个头有自己单独的key-value对;在Multi-query attention结构中只会有一组key-value对;在Grouped-query attention对attention进行分组操作,query被分为N组,每一组分别与一对key-value对进行映射。

在基于Multi-head多头结构变为Grouped-query分组结构的时候,也是采用跟2.1一样的方法,对每一组的key-value对进行mean pool的操作进行参数融合。融合后的模型能力更综合,精度比Multi-query好,同时速度比Multi-head快。
在这里插入图片描述

3. 应用

在llama2中有用到GQA, 在推理过程中由于多个query会复用相同的key-value对,所以对于KV-Cache存储会减少对key-value对的存储,减少了 n_heads / n_kv_heads 倍,这里的n_heads是原始Multi-head的头数,n_kv_headsGrouped-query分组后每组中key-value对的数量。

在实际使用中,会根据从压缩后的key-value对进行还原操作,也就是repeat操作。在llama2中代码如下:

# https://github.com/facebookresearch/llama/blob/4d92db8a1db6c7f663252bf3477d2c4b8bad2385/llama/model.py#L77
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

repeat_kv在github上一个示例如下, 假设repeat 2次操作:

>>> x = torch.rand(1, 2, 3, 4)
>>> x
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

4. 参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/936044
推荐阅读
相关标签
  

闽ICP备14008679号