当前位置:   article > 正文

大模型面经—GQA(Grouped Query Attention)和MHA、MQA的区别及代码_grouped query attention 代码

grouped query attention 代码

大模型面经—GQA(Grouped Query Attention)和MHA、MQA的区别及代码

瓦力算法学研所 2024年07月07日 10:43 安徽

技术总结专栏

本篇介绍分组查询注意力机制。

在大模型技术中,GQA(Grouped Query Attention)是一种注意力机制,它介于MHA(Multi-Head Attention)和MQA(Multi-Query Attention)之间,旨在结合两者的优点,以实现在保持MQA推理速度的同时接近MHA的精度 。

MHA是一种基础的注意力机制,它通过将输入分割成多个头(heads)来并行计算注意力,每个头学习输入的不同部分,最终将结果合并,以捕获序列的不同方面信息 。

MQA则是一种优化的注意力机制,它通过所有头共享相同的键(keys)和值(values),减少了参数量和计算量,从而加快了推理速度,但可能会牺牲一些精度 。

GQA作为MHA和MQA的折中方案,它将查询头(query heads)分组,每组共享一个键和值,而不是所有头都共享。这样,GQA能够在减少计算量的同时,保持更多的多样性,从而在推理速度和模型精度之间取得平衡 。

  • GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)

  • GQA-H组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。

  • GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。

图片

具体来说,GQA通过分组的方式,减少了需要缓存的键和值的数量,从而减少了内存的使用,同时由于不是所有头都共享键和值,它能够比MQA更好地保持MHA的多样性和精度 。例如,如果GQA使用2个头的键和值,那么每个组包含4个查询头,这样在保持速度的同时,精度损失会比MQA小 。

此外,GQA的实现并不复杂,可以通过对现有MHA模型进行少量的训练调整来实现,这使得从MHA到GQA的过渡相对容易 。GQA已经在一些大型语言模型中得到应用,例如Meta开源的LLAMA系列模型 。

总结来说,GQA是一种有效的注意力机制,它通过在MHAMQA之间进行插值,旨在实现更快的推理速度和接近MHA的模型质量,是高负载系统优化的有力工具 。

class  MultiQueryAttention(Attention):    r"""    https://arxiv.org/pdf/1911.02150.pdf    """    def __init__(self, word_size: int = 512, embed_dim: int = 64, n_query:int=8) -> None:        super().__init__(word_size, embed_dim)        self.n_query = n_query        self.proj = nn.Linear(in_features=embed_dim * n_query,                              out_features=embed_dim, bias=False)        delattr(self, 'query')        self.querys = nn.ModuleList([            nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)            for _ in range(n_query)        ])        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
    def forward(self, x: Tensor, mask:Optional[BoolTensor]=None) -> Tensor:        K = self.key(x)        V = self.value(x)        Z_s = torch.cat([            self.self_attention(query(x), K, V, mask) for query in self.querys        ], dim=1)        Z = self.proj(Z_s)        return Z

class  GroupedQueryAttention(Attention):    r"""    https://arxiv.org/pdf/2305.13245.pdf    """    def __init__(self, word_size: int = 512, embed_dim: int = 64,                 n_grouped: int = 4, n_query_each_group:int=2) -> None:        super().__init__(word_size, embed_dim)        delattr(self, 'query')        delattr(self, 'key')        delattr(self, 'value')
        self.grouped = nn.ModuleList([            MultiQueryAttention(word_size, embed_dim, n_query=n_query_each_group)            for _ in range(n_grouped)        ])        self.proj = nn.Linear(in_features=embed_dim * n_grouped,                              out_features=embed_dim, bias=False)
    def forward(self, x: Tensor, mask:Optional[BoolTensor]=None) -> Tensor:        Z_s = torch.cat([head(x, mask) for head in self.grouped], dim=1)        Z = self.proj(Z_s)        return Z

想要获取技术资料的同学欢迎关注公众号,进群一起交流~

参考文献:

https://arxiv.org/pdf/2305.13245

https://cyrilzakka.github.io/llm-playbook/nested/gqa.html

https://github.com/knotgrass/attention/blob/main/attn/attention.py

https://blog.csdn.net/baoyan2015/article/details/137968408

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/936035
推荐阅读
相关标签
  

闽ICP备14008679号