当前位置:   article > 正文

[论文翻译]GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

gqa: training generalized multi-query transformer models from multi-head che

引言

今天带来论文GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints的笔记。

多查询注意力(Multi-query attention,MQA)1仅使用单个key-value头,极大地加速了解码器的推理。但同时也带来了性能降级,而且为了更快的推理而训练一个单独的模型可能并不理想。作者(1)提出了一种将现有的多头语言模型检查点提升为使用仅占原始预训练计算5%的MQA模型的方法,以及(2)引入了Grouped-query attention(GQA,分组查询注意力),这是多查询注意力的一种泛化形式,使用折中数量的key-value头(超过一个,但少于多头注意力全部的头数量)。

作者展示了经过提升的GQA模型在质量上接近于多头注意力,并具有与MQA相当的速度。

总体介绍

由于每个解码步骤中加载解码器权重以及所有注意力key和value所带来的内存带宽开销,自回归解码器推理对Transformer模型来说是一个严重的瓶颈。通过多查询注意力1,加载键和值的内存带宽可以大大减少,它使用多个查询头,但是只有单个键头和值头。

然而,多查询注意力可能导致质量下降和训练不稳定,并且训练适用于质量和推理的单独模型可能不可行。此外,尽管一些语言模型已经使用了多查询注意力,例如PaLM,但许多语言模型尚未使用,包括公开可用的语言模型如T5和LLaMA。

本工作对于更快的大型语言模型推理提供了两个贡献。首先,展示了多头注意力(Multi-head attention,MHA)的语言模型检查点可以通过使用一小部分原始训练计算来进行提升,以使用多查询注意力。这提供了一种经济高效的方法来获得快速的多查询和高质量的MHA检查点。

其次,提出了分组查询注意力(GQA),它是多头和多查询注意力之间的插值(interpolation),每个查询头子组使用单个键头和值头。

方法

提升

image-20240413222445128

将多头模型转换为多查询模型分为两个步骤:首先是转换检查点,然后是额外的预训练,以使模型适应其新的结构。图1展示了将多头检查点转换为多查询检查点的过程。键头和值头的投影矩阵被平均池化成单个投影矩阵,作者发现这比选择单个键头和值头或从头开始随机初始化新的键头和值头效果更好。

然后,转换后的检查点会按照与原始预训练相同的方式进行预训练,但仅进行小部分α比例的训练步骤。

分组查询注意力

image-20240413222803653

分组查询注意力将查询头分为 G G G组,每组共享一个键头和值头。GQA-G表示具有 G G G​组的分组查询注意力。GQA-1表示具有单个组,因此只有一个键头和一个值头,等同于MQA;而GQA-H表示组数等于头数,等同于MHA。

图2展示了分组查询注意力和多头/多查询注意力的比较。在将多头检查点转换为GQA检查点时,作者通过对该组中的所有原始头进行平均池化来构建每个组的键头和值头。

中间组数导致插值模型的质量高于MQA,但快于MHA,这代表了一个有利的权衡。从MHA转换为MQA将 H H H个键头和值头减少为单个键头和值头,将键值缓存的大小减小,并因此将需要加载的数据量减少了 H H H倍。然而,较大的模型通常会扩展头的数量,因此多查询注意力在内存带宽和容量方面的削减更加激进。GQA使我们能够在模型大小增加时保持相同的带宽和容量的比例减小。

此外,较大的模型相对较少受到来自注意力的内存带宽开销的影响,因为KV缓存与模型维度成比例增长,而模型的FLOPs和参数与模型维度的平方成比例增长。最后,大型模型的标准分片通过模型分区的数量复制单个键头和值头;GQA消除了这种分区带来的浪费。因此,我们期望GQA对于较大的模型具有特别好的权衡。

作者注意到GQA不适用于编码器自注意层;编码器的表示是并行计算的,因此内存带宽通常不是主要的瓶颈。

实验

实验设定

配置 所有的模型都是基于T5.1.1架构实现的,使用JAX、Flax和Flaxformer。对于主要实验,考虑了带有多头注意力的T5 Large和XXL,以及带有多查询和分组查询注意力的T5 XXL的重新训练版本。使用Adafactor优化器,并使用与T5相同的超参数和学习率调度。将MQA和GQA应用于解码器的自注意力和交叉注意力,但不应用于编码器的自注意力。

提升 提升(Uptraining)指的是从公共T5.1.1检查点开始初始化重新训练的模型。键头和值头被平均池化到适当的MQA或GQA结构中,然后按照来自T5的原始预训练设置和数据集进行进一步的预训练,进行原始预训练步骤的α比例。对于 α = 0.05 α = 0.05 α=0.05​,训练大约需要600个TPUv3芯片天的时间。

数据 对摘要数据集进行评估,包括CNN/Daily Mail、arXiv和PubMed、MediaSum和Multi-News;翻译数据集WMT 2014英德翻译;以及问答数据集TriviaQA。不在常见的分类基准测试中进行评估,如GLUE,因为自回归推断对于这些任务的适用性较低。

微调 对于微调,对所有任务使用恒定的学习率0.001,批大小128,丢弃率0.1。CNN/Daily Mail和WMT使用512的输入长度和256的输出长度。其他摘要数据集使用2048的输入长度和512的输出长度。最后,TriviaQA使用2048的输入长度和32的输出长度。训练直到收敛,并选择开发集性能最好的检查点。使用贪婪解码进行推断。

耗时 作者报告每个TPUv4芯片每个样本的时间,这是通过xprof进行测量得出的。对于计时实验,使用8个TPU,批大小为32,每个TPU最大容纳的批大小,且并行化针对每个模型分别进行优化。

主要结果

image-20240413224615232

图3显示了MHA T5-Large和T5-XXL以及经过上训练的MQA和GQA-8 XXL模型在平均推断时间方面的所有数据集的平均性能。我们可以看到,相对于MHA模型而言,更大的上训练MQA模型提供了一个有利的权衡,具有比MHA-Large更高的质量和更快的推断速度。此外,GQA实现了显著的额外质量提升,其性能接近MHA-XXL,而速度接近MQA。表1包含了所有数据集的完整结果。

image-20240413224650718

消融

本节介绍了对不同建模选择的影响进行实验的结果。对一部分代表性任务进行性能评估,包括CNN/Daily Mail(短篇摘要)、MultiNews(长篇摘要)和TriviaQA(问答)。

image-20240413224821763

检查点转换 图4比较了不同方法在检查点转换方面的性能。平均池化似乎效果最好,其次是选择一个单独的头部,然后是随机初始化。直观来看,结果按照预训练模型中信息保存的程度进行排序。

image-20240413224911927

提升步骤 图5展示了T5 XXL模型在MQA和GQA上训练比例变化时的性能变化情况。首先,我们注意到GQA在转换后已经达到了合理的性能,而MQA则需要上训练才能发挥作用。无论是MQA还是GQA,都在5%的提升中获益,而从10%开始,收益递减。

image-20240413225045609

分组数量 图6展示了GQA组数对推理速度的影响。对于较大的模型来说,来自KV缓存的内存带宽开销较小,而由于头部数量增加,键值大小的减少更为明显。因此,从MQA增加组数只会导致初始时的适度减速,但随着接近MHA,成本随之增加。作者选择了8个组作为一个较为合适的折中方案。

相关工作

本篇工作的重点是通过减少加载键和值所产生的内存带宽开销,在解码器质量和推理时间之间实现更好的权衡。Shazeer1首次提出通过多查询注意力来减少这种开销。后续的研究表明,多查询注意力对于长输入特别有帮助。Rabe独立开发了具有公共实现的GQA。其他工作探索了将注意力头部进行分组以提高计算效率,但并没有特别关注决定内存带宽开销的键-值头部。

还有许多其他方法被提出来减少来自键和值以及参数的内存带宽开销。Flash Attention通过结构化注意力计算来避免物化(materializing)二次注意力分数,从而减少内存使用并加快训练速度。量化通过降低精度来减小权重和激活的大小,包括键和值。模型蒸馏通过使用从较大模型生成的数据对较小模型进行微调,而不是在给定精度下减小模型大小。Layer-sparse Cross-Attention消除了大部分组成较长输入的交叉注意力层,这是主要的开销。推测采样通过用较小模型提出多个标记,然后由较大模型并行评分来改善内存带宽瓶颈。

最后,作者提出的提升过程受到Komatsuzaki等的启发,他们将标准T5检查点上训练为稀疏激活的专家混合模型。

结论

作者提出将多头注意力模型转换为多查询模型,只需使用原始预训练计算的一小部分。此外,引入了分组查询注意力,它是多查询和多头注意力的插值,既能在与多查询注意力相当的速度下达到接近多头注意力的质量。

总结

⭐ 本篇工作改进了多查询注意力(Multi-query attention,MQA),在query数量不变的情况下,MQA仅使用单个key-value头,这能极大地加速解码器的推理,但也带来了性能下降。作者提出的分组查询注意力,简单来说就是增加了key-value头的数量,但显著少于MHA中的注意力头数。属于MQA和MHA中间的一个权衡,经过实验证明取得了不错的效果。

引用

A 训练稳定性

作者发现,多查询注意力在微调过程中可能导致训练不稳定,特别是在处理长输入任务时。作者从头开始训练了多个T5-Large模型,使用了多查询注意力。在每种情况下,预训练过程中经常出现损失峰值,并且在对长输入任务进行微调时,最终模型立即发散。经过上训练的多查询注意力模型更加稳定,但仍然显示出较高的差异性,因此对于不稳定任务的多查询模型,作者报告了三次微调运行的平均性能。然而,经过提升的分组查询注意力模型似乎是稳定的,因此作者没有进一步调查多查询不稳定性的根本原因。


  1. Fast Transformer Decoding: One Write-Head is All You Need ↩︎ ↩︎ ↩︎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/936062
推荐阅读
相关标签
  

闽ICP备14008679号