赞
踩
多query注意(MQA)仅使用单个key头,大大加快了解码器推理速度。然而,MQA可能导致质量下降,而且仅仅为了更快的推理而训练一个单独的模型可能是不可取的。
MHA(Multi-head Attention)是标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
MQA(Multi-Query Attention,Fast Transformer Decoding: One Write-Head is All You Need)是多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
self.Wqkv = nn.Linear( # 【关键】Multi-Query Attention 的创建方法
d_model,
d_model + 2 * self.head_dim, # 只创建 query 的 head 向量,所以只有 1 个 d_model
device=device, # 而 key 和 value 则只共享各自的一个 head_dim 的向量
)
从MHA生成MQA分为两个步骤:
分组查询注意GQA将query头分为G组,每组共享一个键头和值头。
GQA-G是指具有G组的分组查询。GQA-1具有单个组,因此具有单个键值头,相当于MQA,而GQA-H具有与头数量相等的组,相当于MHA。
图2显示了分组查询注意和多头/多查询注意的比较。当将多头checkpoint转换为GQA checkpoint时,我们通过平均池化该组中的所有原始头来构建每个组键和值头。
中间数量的组导致插值模型比MQA质量更高,但比MHA速度更快,并且,正如我们将展示的那样,代表了有利的权衡。从MHA到MQA将H键和值头减少到单个键和值头,减少了键-值缓存的大小,因此需要加载的数据量减少了H倍。然而,更大的模型通常会扩展头的数量,这样MQA在内存带宽和容量方面都代表了更大的削减。GQA允许我们在模型大小增加时保持相同比例的带宽和容量减少。
此外,由于KV-cache随模型尺寸的增大而增大,而模型FLOPs和参数随模型尺寸的平方而增大,较大的模型受到的注意力带来的内存带宽开销相对较小。最后,大型模型的standard sharding通过模型分区的数量复制单个键和值头(Pope et al., 2022);GQA消除了这种分区的浪费。因此,我们期望GQA能够为更大的模型提供一个特别好的权衡。
所有模型都基于T5.1.1架构(rafael等人,2020),由JAX (Bradbury等人,2018)、flex (Heek等人,2020)和Flaxformer1实现。对于我们的主要实验,我们考虑了具有多头注意力的T5 Large和XXL,以及具有多查询和分组查询注意力的T5 XXL的升级训练版本。我们将MQA和GQA应用于解码器自注意和交叉注意,但不应用于编码器自注意。
从公共T5.1.1检查点初始化经过升级训练的模型。将键头和值头平均池到适当的MQA或GQA结构中,然后使用原始预训练设置对原始预训练步骤进行进一步的α比例预训练(rafael et al., 2020)。α = 0.05
我们评估了CNN/Daily Mail (Nallapati等人,2016)、arXiv和PubMed (Cohan等人,2018)、MediaSum (Zhu等人,2021)和Multi-News (Fabbri等人,2019)的摘要数据集;翻译数据集WMT 2014;和问答数据集TriviaQA (Joshi et al., 2017)。我们没有对GLUE等流行的分类基准进行评估(Wang et al., 2019),因为自回归推理不太适用于这些任务。
对于微调,我们使用恒定的学习率为0.001,批大小为128,所有任务的dropout为0.1。CNN/Daily Mail和WMT使用的输入长度为512,输出长度为256。其他摘要数据集使用输入长度2048和输出长度512。最后,TriviaQA使用输入长度2048和输出长度32。我们训练直到收敛,并选择最高性能的checkpoint。我们使用贪婪解码进行推理。
下图提供实验来研究不同建模选择的影响。我们评估了具有代表性的任务子样本的性能:CNN/Daily Mail,(简短摘要),MultiNews(长格式摘要)和TriviaQA(问答)。
不同检查点转换方法对T5-Large上训练到比例为α = 0.05的MQA的性能比较。“Mean”是指将键和值头 做mean-pool,“First”选择第一个头,“Random”从头开始初始化头。
图6展示了GQA组的数量对推断速度的影响。对于较大的模型,KV缓存的内存带宽开销约束较少(Shazeer, 2019),而由于头数量的增加,键值大小的减少更为明显。因此,增加来自MQA的组的数量最初只会导致适度的减速,随着我们接近MHA而增加成本。我们选择了8组作为有利的中间立场。
GQA- xxl的每个样本时间作为输入长度2048和输出长度512的GQA组数量的函数。从1个(MQA)组增加到8个组会增加适度的推理开销,增加更多组会增加成本
为了创建新的Llama 2模型家族,我们从Touvron等人(2023)中描述的预训练方法开始,使用优化的自回归变压器,但进行了一些更改以提高性能。具体来说,我们执行了更健壮的数据清理,更新了我们的数据混合,训练了超过40%的总令牌,将上下文长度增加了一倍,并使用分组查询关注(GQA)来提高更大模型的推理可伸缩性。表1比较了新Llama 2模型与Llama 1模型的属性。
我们采用了Llama 1中的大部分预训练设置和模型架构。我们使用标准变压器架构(Vaswani等人,2017),使用RMSNorm应用预归一化(Zhang和Sennrich, 2019),使用SwiGLU激活函数(Shazeer, 2020)和旋转位置嵌入(RoPE, Su等人,2022)。与Llama 1的主要架构差异包括增加了上下文长度和分组查询关注(GQA)。我们在附录A.2.1节中通过烧蚀实验详细说明这些差异的重要性。
更大的模型——34B和70B——使用分组查询注意(GQA)来提高推理的可扩展性
自回归解码的标准做法是缓存序列中前面标记的键(K)和值(V)对,从而加快注意力计算。然而,随着上下文窗口或批处理大小的增加,多头注意(MHA)模型中与KV缓存大小相关的内存成本显著增加。对于较大的模型,KV缓存大小成为瓶颈,键和值预测可以在多个头之间共享,而不会导致性能下降(Chowdhery et al., 2022)。可以使用具有单个KV投影的原始多查询格式(MQA, Shazeer, 2019)或具有8 KV投影的分组查询关注变体(GQA, Ainslie等人,2023)。
在表18中,我们用MHA基线比较了MQA和GQA变体。我们用150B令牌训练所有模型,同时保持固定的30B模型大小。为了在GQA和MQA之间保持相似的总体参数计数,我们增加了前馈层的维度,以补偿注意层的减少。对于MQA变体,我们将FFN维度增加1.33倍,对于GQA变体,我们将其增加1.3倍。从结果中,我们观察到GQA变体在大多数评估任务上的表现与MHA基线相当,并且平均优于MQA变体。
为了优化延迟,我们在具有张量并行性的单个节点上使用8个a100来托管最大的模型(Shoeybi等人,2019)。在这种设置中,由于头的数量低于gpu的数量,因此不能再跨头进行MQA分片。要么在所有gpu中复制KV值(使KV缓存大小等于GQA),要么在批处理维度上进行分片(Pope et al., 2022)。然而,后者可能会使推理服务复杂化,因为它只有在批处理大小大于分片数量时才有效,并且在所有情况下,额外的通信成本都不值得。
因此,基于消融结果和缩放推理的易用性,对于34B和70B Llama 2模型,我们选择使用GQA而不是MQA。图24显示了与MHA基线相比,30B GQA和MQA消融模型的推理速度变化情况,使用8 x 80 GiB a100进行张量并行实验。在这些运行中,我们只是在所有gpu中复制MQA的KV头,因此MQA的KV缓存大小与GQA相等,并且这两个变体的行为非常相似(MQA只是具有稍微大一点的FFN维度)。
Multi-query variants可以在更大的批处理大小下实现更高的吞吐量,并且在较小的批处理上显示类似的延迟。输出长度固定为128个令牌。第一个数据点对应于批处理大小1,然后我们将其加倍,直到模型耗尽内存。对于256个令牌的上下文,MHA变体在批大小为1024时触发内存不足错误,而对于2k个上下文,批大小为128时触发内存不足错误,而MQA和GQA在这些设置中可以成功运行。
多查询注意(Multi-query attention, MQA)[15]和分组查询注意(group -query attention, GQA)[1]是注意的变体,其中多个查询头关注同一个键和值的头,以便在推理过程中减少KV缓存的大小。我们不必为计算复制键头和值头,而是隐式地操作头的索引来执行相同的计算。在逆向传递中,我们需要对不同头部的梯度dK和dV求和,这些头部是隐式重复的。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。