赞
踩
感觉这篇paper有几个亮点,首先把Megatron-LM的Self-Attention模块的模型并行方式变成序列并行,优化了通信量,同时通过计算和通信重叠近一步压缩了训练迭代时间。另外,在使用重计算的时候发现当前Huggingface/Megatron-LM的重计算策略和FlashAttentionV2同时工作的话会导致Transformer Layer多计算一次Flash Attention的forward,然后修正了这个问题,获得了很直接的性能提升。paper的代码实现基于Triton并且不算长,后面尝试讲解这里的代码,应该会先从这里的DISTATTN开始。
从 https://github.com/RulinShao/LightSeq 注意到这篇paper(https://arxiv.org/pdf/2310.03294.pdf),paper里面有一些比较有趣的发现并且这个paper的代码是基于Triton来实现的,所以激发了我阅读兴趣。我后续也会从源码的角度来解读这篇paper核心idea的代码实现,顺便学习下Triton。介于篇幅原因,这篇文章只读一下这篇paper,把握一下核心的Infra相关的idea。这篇paper应该还没有中会议,处于openreview阶段。
从题目可以看出这是一个专注于提升LLM长文本训练长度的工作。
提高大型语言模型(LLMs)训练时的上下文长度可以解锁根本性的新能力,但也显著增加了训练的内存占用。Megatron-LM通过模型并行以及并行计算注意力头引入了大量的通信,所以在继续增大模型规模时会受限(在介绍的部分会详细说这里的受限原因)。这篇paper介绍了一种针对长上下文LLMs训练的新方法,LIGHTSEQ。LIGHTSEQ有许多显著的优点。首先,LIGHTSEQ在序列维度上进行切分,所以对模型架构是无感的,且可直接应用于具有不同数量注意力头的模型,如Multi-Head、Multi-Query和Grouped-Query注意力。其次,LIGHTSEQ不仅在流行的LLMs上比Megatron-LM减少了高达4.7倍的通信量,而且还实现了通信与计算的重叠。为了进一步减少训练时间,LIGHTSEQ采用了一种新的Activation Checkpointing方案,以绕过内存高效的自注意力实现的前向过程(指的应该就是FlashAttention)。我们在Llama-7B及其变体上评估了LIGHTSEQ,序列长度从32K到512K。通过在单节点和跨节点训练上的全面实验,我们展示了LIGHTSEQ达到了高达1.24-2.01倍的端到端加速,并且与Megatron-LM相比,LIGHTSEQ在具有更少注意力头的模型上实现了2-8倍更长的序列长度。代码开源在https://github.com/RulinShao/LightSeq。
感觉这里的介绍对理解paper的工作是有好处的,就精准翻译一下。
具有长上下文能力的 Transformer 已经使得一些全新的应用成为可能,例如全面的文档理解、生成完整的代码库以及扩展的互动聊天(Osika, 2023; Liu 等人, 2023; Li 等人, 2023)。然而,训练能处理长序列的大型语言模型(LLMs)会导致大量的Activation内存占用,给现有的分布式系统带来了新的挑战。减少这些大量Activation内存占用的一个有效方法是将Activation切分到不同的设备上。为了实现这一点,现有系统如 Megatron-LM(Korthikanti 等人, 2023; Shoeybi 等人, 2019)通常会切分注意力头。然而,这种设计强假设注意力头的数量必须能被并行度整除,这对许多模型架构来说并不成立。例如,Llama-33B 有 52 个注意力头,这个数量不能被 NVIDIA 集群的常选并行度,如 8、16 和 32 整除。此外,分割注意力头限制了最大并行度不能大于注意力头的数量。然而,许多受欢迎的大型语言模型并没有足够的注意力头来实现并行度扩展,例如 CodeGen模型(Nijkamp 等人, 2022)只有 16 个注意力头。更有甚者,许多研究表明未来的 Transformer 架构设计可能会有更少的注意力头。例如,Bian 等人(2021)展示了单头 Transformer 在性能上超越了多头对应的版本,这对像 Megatron-LM 这样的解决方案来说是一个挑战。为了解除注意力头数的限制,我们提出仅分割输入tokens(即序列并行),而不是注意力头。我们提出了一个与模型架构无关且具有最大并行度随序列长度而随之扩展的解决方案。 具体来说,我们引入了一个可并行化且内存高效的精确注意力机制,DISTATTN(§3.1)。我们的设计使得重叠成为可能,我们可以将通信隐藏进注意力计算中(§ 3.2)。我们还提出了一种负载平衡技术,以避免因工作负载不平衡而导致的在因果语言模型中的计算bubble(§3.2)。在将 FlashAttention(Dao, 2023)算法扩展到 DISTATTN 的过程中,我们找到了一种利用底层重新计算逻辑显著提高gradient checkpointing训练速度的方法(§ 3.3)。这项技术也适用于非分布式使用的内存高效注意力,在我们的实验中转化为额外的 1.31× 速度提升(§ 4.3)。
这里对于注意力头的切分描述我觉得很怪,一般Megatron不是按照TP大小来切分自注意力头吗,而TP大小一般不会超过8的。感觉这里说的TP 16,TP 32是很不常见的设置。
paper的贡献总结如下:
这里涉及到对内存高效的自注意力,序列并行,模型并行,FSDP,Gradient checkpointing等技术的简介,由于只是简要介绍,没有干货,这里就略过了。
这是paper最核心的部分,需要仔细理解。在本节中,我们描述了 LIGHTSEQ 中关键组件的设计。我们首先介绍了一种分布式内存高效注意力机制,DISTATTN(§3.1),它沿序列维度并行化计算。然后,我们引入了一种用于因果语言建模的负载平衡调度,以减少计算bubble,以及一种异步通信设计,将通信与计算重叠(§3.2)。最后,我们提出了一种rematerialization-aware checkpointing 策略(§3.3),有效地减少了在Gradient checkpointing中的重计算时间。
DISTATTN 的核心思想是将包含
N
N
N 个token的输入序列沿着序列维度均匀分割到
P
P
P 个 worker(例如 GPU)上。因此,每个 worker 只负责计算
N
/
P
N/P
N/P 个 token 的前向传递和后向传递。对于像前馈层(FFN)、层标准化(LN)和 Embedding 层这样的模块,token 可以独立计算,无需协调,并且工作在 worker 之间平衡。不幸的是,对于自注意力模块,其中本地 token 可能需要关注远程 token,需要协调。为了解决这个问题,每个 worker 需要收集(gather)与其它 token 关联的所有 key 和 value。为了应对通过收集所有其它 key 和 value 引入的内存压力,这个过程通过在线流式传输,即从拥有靠前 tokens 的 workers 向拥有靠后 tokens 的 workers 传输 key 和 value 来完成。更正式地,用
q
p
q_p
qp、
k
p
k_p
kp、
v
p
v_p
vp 表示持有在
p
p
p (
p
∈
1
,
.
.
.
,
P
p\in {1, ..., P}
p∈1,...,P)个worker上的query、key、value输入,用
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。