当前位置:   article > 正文

大语言模型框架-Megatron-LM源码分析_megatron-lm架构

megatron-lm架构

语言模型框架-Megatron-LM源码分析

原创 MLOps社区 DeepPrompting 2023年11月11日 11:44 北京

Megatron-LM是NVIDIA开源的大语言模型框架,是很多披露的大语言模型的训练使用的源头框架,很多公司基于其二次开发新的语言模型系统,例如Megatron-LM-DeepSpeed。

Megatron核心解决问题就是提供多种分布式切分并行策略,让大语言模型能够部署在多卡分布式环境下。本文将针对,张量并行,流水并行,数据并行的实现展开源码分析。MoE我们可以当成一种特殊的稀疏化结构,就不在本章进行介绍。

图片来自DeepSpeed(本文不介绍ZeRO,感兴趣读者可参考相关论文)

1 张量并行

张量并行分为行切和列切并行(指的是对输入矩阵切法),具体读者可以参考Megatron论文,其实现方式是继承实现Linear层,进而实现其中的并行策略,只需要替换模型中的Linear即可,后面我们也会看到MoE也是这种实现技巧。

图片

图来源Megatron-LM

图片

       world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)通过这部分代码进行并行partition划分,worldsize是配置的tensorparallel的卡数,将完整input切成这么多份数,在每个执行这个代码的rank进行权重创建。

图片

如果输入不是并行切分好的,通过scatter去拿这部分权重对应的输入数据。

图片

scatter[图来源PyTorch官网]

图片

forward实现是配合异步allreduce进而将计算和comm通信并发执行。

图片

图片

    async_grad_allreduce (bool required): Do the allreduce of input
        gradients asyncronously with the computation of weight
        gradients. If sequence_parallel is True, this must be
        False, as no all reduce is performed.是对在BP阶段输入的gradients是否进行异步计算

这个linear前向是标准的torch matmul,除非sequence有并行设置才会进行一定的allgather通信。

图片

之前异步都被RowLinear配置false,所以BP核心是fused kernel。且可以使用低精度16bit的内核。  gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
        extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
        ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
        Defaults to False.


 

图片

 如果不选fuse kernel则执行执行矩阵乘完成BP反向传播计算。

图片

有意思的是Apex库也用的这个linear实现,库之前二次开发逐渐成为当前常用方式

图片

这个fusekernel最终还是调用的cublas的gemm

图片

FP16使用的BF16,且也是cublas gemm,并可以利用tensor core的加速。也有版本可以选择    at::Half* A,类型的bit 16类型内核。

图片

下面是上面的kernel优化目的Gradient accumulation fusion的介绍。

图片

所以C矩阵在kernel的输出是32 bit,输入是两个16bit矩阵保证了输出累计的数据精度。

图片

接下来我们再看列切分配置,其类似行切分。好处是省去下一步的GEMM之前的allreduce通信,所以attention和mlp的第一层gemm megatron选择列切,之后再行切下一个阶段gemm。

图片

其中gather_output由模型设计者决定,进而和后面的层进行配合,看输出是当前层是否需要聚合还是不需要聚合收集。这点是比较有意思,当前属于人工硬编码选择配置好策略。

图片

初始化决定列切切分维度,这个world size 通过tensor parallel的输入shell进行配置。

一般配置是8路,也就是张量并行在一个8卡的server,如果用户是4卡server则配置4即可。

图片

图片

和Row并行的区别是用户配置是否进行一次all gather output

图片

allgather原因是列切完,我们产生了两块局部结果。如下文所示是两个partition 1 和 2。

图片

图片

all gather [图来源PyTorch官网]

2 流水并行

其核心代码在p2p_communication中实现有4个重要参数。先沟通改下需要传递的tensor shape,"""Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

图片

下面是真正的传递核心,使用的P2POp,isend传递是异步传递,同时通过函数去确定相邻的rank,这样写这个函数就不用管拓扑了。

图片

然后将这些操作符做批量通信发送,一批完成和自己上游和下游的send recv异步通信

图片

调用在_communicate进行

图片

当前方式是相当于有了microbatch的发送方式,也就是既有下个microbatch的前向,也有当前batch的反向,一批次做异步通信。适合整体都已经运行起来了,已经不是第一个batch的场景。

图片

这是在backward中调用的上面的API

图片

流水并行的切分靠的是这个函数获取自己的这个rank到底是哪个model chunk,相当于静态切分好自己是属于哪块,类似张量并行。当前并行策略都是可以考虑这种分配方式,静态编译好partition和rank映射关系,启动后获取这个关系决定自己的通信方式。

图片

前向传播通过以下函数进行p2p通信。

图片

底层还是通过_communicate调用实现类似上面的bp过程。当前相当于只需要给下个rank send所以tensor_send_prev为空。如果多个microbatch配置是第一个batch前向,或者没配置多个microbatch场景使用。

图片

3 数据并行

数据并行一般有几点优化:将梯度通信和BP计算进行overlap,同时可以使用低精度做梯度聚合,将grad组成成小桶聚合。

图片

图片

图来自PyTorch,桶聚合相当于batching通信张量

通过hook注册梯度更新事件

图片

图片

图片

Overlap的核心是当bucket的 paramer累计到都有gradient,触发allreduce同步。BP计算该做自己的继续做,产生好的gradient,这部分只要ready就同时触发allreduce与进行的BP就无关了,但是没产生的gradient还不能进行计算。

图片

grad_buffer中是核心逻辑。

图片

图片

allreduce,图片来源PyTorch

核心逻辑通过GradBuffer进行聚合成连续buffer,再拆解。

图片

本质是allreduce并可以选择是否是异步。

图片

在finish中不同等待异步allreduce

图片

此处聚合wait所有的handler同步通信。

图片

    config.finalize_model_grads_func = finalize_model_grads
在没有流水并行下,执行刚才的同步通信等待。

图片

AI Infra96

AI Infra · 目录

上一篇大语言模型内核源码分析-4 Paging推理内核下一篇大语言模型内核源码分析-4IO-Awareness内核

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/868205
推荐阅读
相关标签
  

闽ICP备14008679号