当前位置:   article > 正文

AI Infra论文阅读之LIGHTSEQ(LLM长文本训练的Infra工作)_ai训练 tp 16 tp 128

ai训练 tp 16 tp 128

感觉这篇paper有几个亮点,首先把Megatron-LM的Self-Attention模块的模型并行方式变成序列并行,优化了通信量,同时通过计算和通信重叠近一步压缩了训练迭代时间。另外,在使用重计算的时候发现当前Huggingface/Megatron-LM的重计算策略和FlashAttentionV2同时工作的话会导致Transformer Layer多计算一次Flash Attention的forward,然后修正了这个问题,获得了很直接的性能提升。paper的代码实现基于Triton并且不算长,后面尝试讲解这里的代码,应该会先从这里的DISTATTN开始。

0x0. 前言

从 https://github.com/RulinShao/LightSeq 注意到这篇paper(https://arxiv.org/pdf/2310.03294.pdf),paper里面有一些比较有趣的发现并且这个paper的代码是基于Triton来实现的,所以激发了我阅读兴趣。我后续也会从源码的角度来解读这篇paper核心idea的代码实现,顺便学习下Triton。介于篇幅原因,这篇文章只读一下这篇paper,把握一下核心的Infra相关的idea。这篇paper应该还没有中会议,处于openreview阶段。

在这里插入图片描述
从题目可以看出这是一个专注于提升LLM长文本训练长度的工作。

0x1. 摘要

在这里插入图片描述提高大型语言模型(LLMs)训练时的上下文长度可以解锁根本性的新能力,但也显著增加了训练的内存占用。Megatron-LM通过模型并行以及并行计算注意力头引入了大量的通信,所以在继续增大模型规模时会受限(在介绍的部分会详细说这里的受限原因)。这篇paper介绍了一种针对长上下文LLMs训练的新方法,LIGHTSEQ。LIGHTSEQ有许多显著的优点。首先,LIGHTSEQ在序列维度上进行切分,所以对模型架构是无感的,且可直接应用于具有不同数量注意力头的模型,如Multi-Head、Multi-Query和Grouped-Query注意力。其次,LIGHTSEQ不仅在流行的LLMs上比Megatron-LM减少了高达4.7倍的通信量,而且还实现了通信与计算的重叠。为了进一步减少训练时间,LIGHTSEQ采用了一种新的Activation Checkpointing方案,以绕过内存高效的自注意力实现的前向过程(指的应该就是FlashAttention)。我们在Llama-7B及其变体上评估了LIGHTSEQ,序列长度从32K到512K。通过在单节点和跨节点训练上的全面实验,我们展示了LIGHTSEQ达到了高达1.24-2.01倍的端到端加速,并且与Megatron-LM相比,LIGHTSEQ在具有更少注意力头的模型上实现了2-8倍更长的序列长度。代码开源在https://github.com/RulinShao/LightSeq。

0x2. 介绍

感觉这里的介绍对理解paper的工作是有好处的,就精准翻译一下。

具有长上下文能力的 Transformer 已经使得一些全新的应用成为可能,例如全面的文档理解、生成完整的代码库以及扩展的互动聊天(Osika, 2023; Liu 等人, 2023; Li 等人, 2023)。然而,训练能处理长序列的大型语言模型(LLMs)会导致大量的Activation内存占用,给现有的分布式系统带来了新的挑战。减少这些大量Activation内存占用的一个有效方法是将Activation切分到不同的设备上。为了实现这一点,现有系统如 Megatron-LM(Korthikanti 等人, 2023; Shoeybi 等人, 2019)通常会切分注意力头。然而,这种设计强假设注意力头的数量必须能被并行度整除,这对许多模型架构来说并不成立。例如,Llama-33B 有 52 个注意力头,这个数量不能被 NVIDIA 集群的常选并行度,如 8、16 和 32 整除。此外,分割注意力头限制了最大并行度不能大于注意力头的数量。然而,许多受欢迎的大型语言模型并没有足够的注意力头来实现并行度扩展,例如 CodeGen模型(Nijkamp 等人, 2022)只有 16 个注意力头。更有甚者,许多研究表明未来的 Transformer 架构设计可能会有更少的注意力头。例如,Bian 等人(2021)展示了单头 Transformer 在性能上超越了多头对应的版本,这对像 Megatron-LM 这样的解决方案来说是一个挑战。为了解除注意力头数的限制,我们提出仅分割输入tokens(即序列并行),而不是注意力头。我们提出了一个与模型架构无关且具有最大并行度随序列长度而随之扩展的解决方案。 具体来说,我们引入了一个可并行化且内存高效的精确注意力机制,DISTATTN(§3.1)。我们的设计使得重叠成为可能,我们可以将通信隐藏进注意力计算中(§ 3.2)。我们还提出了一种负载平衡技术,以避免因工作负载不平衡而导致的在因果语言模型中的计算bubble(§3.2)。在将 FlashAttention(Dao, 2023)算法扩展到 DISTATTN 的过程中,我们找到了一种利用底层重新计算逻辑显著提高gradient checkpointing训练速度的方法(§ 3.3)。这项技术也适用于非分布式使用的内存高效注意力,在我们的实验中转化为额外的 1.31× 速度提升(§ 4.3)

这里对于注意力头的切分描述我觉得很怪,一般Megatron不是按照TP大小来切分自注意力头吗,而TP大小一般不会超过8的。感觉这里说的TP 16,TP 32是很不常见的设置。

paper的贡献总结如下:

  • 我们设计了 LIGHTSEQ,这是一个基于序列级并行的长上下文大型语言模型(LLM)训练原型。我们开发了一种分布式内存高效精确注意力机制 DISTATTN,采用了新的负载平衡和用于因果语言模型的计算和通信重叠调度。
  • 我们提出了一种新的检查点策略,当使用内存高效注意力与gradient checkpointing训练时,可以绕过一个注意力前向传播。
  • 我们在 Llama-7B 及其不同注意力头模式的变体上评估了 LIGHTSEQ,并展示了与 Megatron-LM 相比,在长上下文训练中高达 2.01× 的端到端加速。我们进一步展示了 LIGHTSEQ 能够超越注意力头的数量,实现 2-8× 更长序列的训练。

0x3. 相关工作

这里涉及到对内存高效的自注意力,序列并行,模型并行,FSDP,Gradient checkpointing等技术的简介,由于只是简要介绍,没有干货,这里就略过了。

0x4. 方法

这是paper最核心的部分,需要仔细理解。在本节中,我们描述了 LIGHTSEQ 中关键组件的设计。我们首先介绍了一种分布式内存高效注意力机制,DISTATTN(§3.1),它沿序列维度并行化计算。然后,我们引入了一种用于因果语言建模的负载平衡调度,以减少计算bubble,以及一种异步通信设计,将通信与计算重叠(§3.2)。最后,我们提出了一种rematerialization-aware checkpointing 策略(§3.3),有效地减少了在Gradient checkpointing中的重计算时间。

0x4.1 分布式高效自注意力计算

在这里插入图片描述
DISTATTN 的核心思想是将包含 N N N 个token的输入序列沿着序列维度均匀分割到 P P P 个 worker(例如 GPU)上。因此,每个 worker 只负责计算 N / P N/P N/P 个 token 的前向传递和后向传递。对于像前馈层(FFN)、层标准化(LN)和 Embedding 层这样的模块,token 可以独立计算,无需协调,并且工作在 worker 之间平衡。不幸的是,对于自注意力模块,其中本地 token 可能需要关注远程 token,需要协调。为了解决这个问题,每个 worker 需要收集(gather)与其它 token 关联的所有 key 和 value。为了应对通过收集所有其它 key 和 value 引入的内存压力,这个过程通过在线流式传输,即从拥有靠前 tokens 的 workers 向拥有靠后 tokens 的 workers 传输 key 和 value 来完成。更正式地,用 q p q_p qp k p k_p kp v p v_p vp 表示持有在 p p p p ∈ 1 , . . . , P p\in {1, ..., P} p1,...,P)个worker上的query、key、value输入,用

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/759457
推荐阅读
相关标签