当前位置:   article > 正文

深度学习框架 —— 分布式训练

分布式训练

现在深度学习的模型结构越来越大,参数动不动都是上亿甚至上千亿,这也对训练模型的资源量有很高的要求,显然单个机器上要训练这么大的网络是不现实的,因此学术界和工业界自然开始研究用分布式训练。也就是将一个机器学习模型任务拆分成多个子任务,并将子任务分发给多个计算节点,解决资源瓶颈。

1. 分布式训练概述 / 动机

分布式训练的动机很简答:单节点算力和内存不足,因此不得不做分布式训练。

训练机器学习模型需要大量内存。假设一个大型神经网络模型具有 1000 亿的参数(LLM 时代有不少比这个参数量更大的模型),每个参数都由一个 32 位浮点数(4 个字节)表达,存储模型参数就需要 400GB 的内存。在实际中,我们需要更多内存来存储激活值和梯度。假设激活值和梯度也用 32 位浮点数表达,那么其各自至少需要 400GB 内存,总的内存需求就会超过 1200GB(即 1.2TB)。而如今的硬件加速卡(如 NVIDIA A100)仅能提供最高80GB的内存。单卡内存空间的增长受到硬件规格、散热和成本等诸多因素的影响,难以进一步快速增长。因此,我们需要分布式训练系统来同时使用数百个训练加速卡,从而为千亿级别的模型提供所需的TB级别的内存。

为了方便获得大量用于分布式训练的服务器,我们往往依靠云计算数据中心。一个数据中心管理着数百个集群,每个集群可能有几百到数千个服务器。通过申请其中的数十台服务器,这些服务器进一步通过分布式训练系统进行管理,并行完成机器学习模型的训练任务。

相比单节点执行,可以看到分布式训练主要要对输入做一些切分(可能是切分数据集、切分模型参数或者混合的方式)并送到不同的节点进行训练,由于数据或参数的独立性,各节点之间的计算是相对独立的,因此可以并行计算,从而达到加速模型训练的目的。

为了确保分布式训练系统的高效运行,需要首先估计系统计算任务的计算和内存用量。假如某个任务成为了瓶颈,系统会切分输入数据,从而将一个任务拆分成多个子任务。子任务进一步分发给多个计算节点并行完成。

一个模型训练任务(Model Training Job)往往会有一组数据(如训练样本)或者任务(如算子)作为输入,利用一个计算节点(如GPU)生成一组输出(如梯度)。

分布式执行一般具有三个步骤:

  • 第一步将输入进行切分;
  • 第二步将每个输入部分会分发给不同的计算节点,实现并行计算;
  • 第三步将每个计算节点的输出进行合并,最终得到和单节点等价的计算结果。

这种首先切分,然后并行,最后合并的模式,本质上实现了分而治之(Divide-and-Conquer)的方法(实际上在计算机的系统领域这是一个非常主流的解决思路):由于每个计算节点只需要负责更小的子任务,因此其可以更快速地完成计算,最终实现对整个计算过程的加速。

2. 分布式训练主流实现方法

分布式训练系统的设计目标是:将单节点训练系统转换成 等价的 并行训练系统(divide and conquer),从而在不影响模型精度的条件下完成训练过程的加速。

下图是一个单节点训练系统的流程:为了更新参数(模型训练一个 iter 需要更新模型参数),计算图的执行分为前向计算和反向计算两个阶段。前向计算的第一步会将数据读入第一个算子,该算子会根据当前的参数,计算出计算给下一个算子的数据。算子依次重复这个前向计算的过程(执行顺序:算子1,算子2,算子3),直到最后一个算子结束。最后的算子随之马上开始反向计算。反向计算中,每个算子依次计算出梯度(执行顺序:梯度3,梯度2,梯度1),并利用梯度更新本地的参数。反向计算最终在第一个算子结束。反向计算的结束也标志本次数据小批次的结束,系统随之读取下一个数据小批次,继续更新模型。

再进一步对这个过程拆分:模型训练过程中有频繁的读数据计算这两个操作,因此我们可以从这两个方面入手,对数据和计算分别拆分,从而完成并行加速。

下表是分布式训练方法的分类:

分类单数据多数据
单程序单程序单数据:单点执行单程序多数据:数据并行
多程序多程序单数据:模型并行多程序多数据:混合并行

单节点训练系统可以被归类于单程序单数据模式。而假如用户希望使用更多的设备实现并行计算,首先可以选择对数据进行分区,并将同一个程序复制到多个设备上并行执行。这种方式是单程序多数据模式,常被称为数据并行(Data Parallelism)。另一种并行方式是对程序进行分区(模型中的算子会被分发给多个设备分别完成)。这种模式是多程序单数据模式,常被称为模型并行(Model Parallelism)。当训练超大型智能模型时,开发人员往往要同时对数据和程序进行切分,从而实现最高程度的并行。这种模式是多程序多数据模式,常被称为混合并行(Hybrid Parallelism)

2.1 数据并行

数据并行常见的应用有:PyTorch 和 MegEngine 的 Distributed,也就是起多机进行训练,主要是解决单机算力不足的问题。

在一个数据并行系统中,假设用户给定一个训练批大小为 N,并且希望使用 M 个并行设备来加速训练。那么,该训练批大小会被分为 M 个分区,每个设备会分配到 N / M 个训练样本。这些设备共享一个训练程序的副本,在不同数据分区上独立执行、计算梯度。不同的设备(假设设备编号为 i)会根据本地的训练样本计算出梯度 Gi. 为了确保训练程序参数的一致性,本地梯度 Gi 需要聚合(reduce,各个进程需要和主进程通信),计算出平均梯度。最终,训练程序利用平均梯度修正模型参数,完成小批次的训练。

下图展示了两个设备构成的数据并行训练系统(Data Parallel Training System)的例子。假设用户给定的数据批大小是 64,那么每个设备会分配到 32 个训练样本,并且具有相同的神经网络参数(程序副本)。本地的训练样本会依次通过这个程序副本中的算子,完成前向计算和反向计算。在反向计算的过程中,程序副本会生成局部梯度。不同设备上对应的局部梯度(如设备 1 和设备 2 上各自的梯度1)会进行聚合,从而计算平均梯度。这个聚合的过程往往由集合通信的 AllReduce 操作完成(用 cuda 的话一般是通过 NCCL 来完成)。

2.2 模型并行

**模型并行往往用于解决单节点内存不足的问题。**一个常见的内存不足场景是模型中含有大型算子,例如深度神经网络中需要计算大量分类的全连接层。完成这种大型算子计算所需的内存可能超过单设备的内存容量。那么需要对这个大型算子进行切分。假设这个算子具有 P 个参数,而系统拥有 N 个设备,那么可以将 P 个参数平均分配给 N 个设备,从而让每个设备负责更少的计算量,能够在内存容量的限制下完成前向计算和反向计算。这种切分方式是模型并行训练系统(Model Parallelism Training System)的一种应用,也被称为 算子内并行 (Intra-operator Parallelism)。

下图是一个模型并行的流程图,同样的一份数据被广播成两份给两个设备分别计算,两个设备的计算并不相同,分别计算出结果之后再 Gather 汇总结果(到主进程)。

在这个例子中,假设一个神经网络具有两个算子,算子 1 的计算(包含正向和反向计算)需要预留 16 GB的内存,算子 2 的计算需要预留 1GB 的内存。而本例中的设备最多可以提供 10GB 的内存。为了完成这个神经网络的训练,需要对算子 1 实现并行。具体做法是,将算子 1 的参数平均分区,设备 1 和设备 2 各负责其中部分算子1的参数。由于设备 1 和设备 2 的参数不同,因此它们各自负责程序分区 1 和程序分区 2。在训练这个神经网络的过程中,训练数据(按照一个小批次的数量)会首先传给算子 1。由于算子 1 的参数分别由两个设备负责,因此数据会被广播(Broadcast)给这两个设备。不同设备根据本地的参数分区完成前向计算,生成的本地计算结果需要进一步合并,发送给下游的算子 2。在反向计算中,算子 2 的数据会被广播给设备 1 和设备 2,这些设备根据本地的算子 1 分区各自完成局部的反向计算。计算结果进一步合并计算回数据,最终完成反向计算。

另一种内存不足的场景是:模型的总内存需求超过了单设备的内存容量。在这种场景下,假设总共有 N 个算子和 M 个设备,可以将算子平摊给这 M 个设备,让每个设备仅需负责 N / M 个算子的前向和反向计算,降低设备的内存开销。这种并行方式是模型并行的另一种应用,被称为算子间并行 (Inter-operator Parallelism)。

下图给出了一个由两个设备实现的算子间并行的例子。在这个例子中,假设一个神经网络具有两个算子,算子 1 和算子 2 各自需要 10GB 的内存完成计算,则模型总共需要 20GB 的内存。而每个设备仅能提供 10GB 内存。在这个例子中,用户可以把算子 1 放置在设备 1 上,算子 2 放置在设备 2 上。在前向计算中,算子 1 的输出会被发送给下游的设备 2。设备 2 接收来自上游的数据,完成算子 2 的前向计算。在反向计算中,设备 2 将算子 2 的反向计算结果发送给设备 1。设备 1 完成算子 1 的反向计算,完成本次小批次(Mini-Batch)的训练。

简单来说,并行的思路就是:各算各的,算完再汇总。

2.3 混合并行

在训练大型人工智能模型中,往往会同时面对算力不足和内存不足的问题。因此,需要混合使用数据并行和模型并行,这种方法被称为混合并行

下图就是一个混合并行的例子,数据集被切分到不同的机器上执行,同样的数据集又会被切分到不同的设备上执行不同的计算。这里提供了一个由 4 个设备实现的混合并行的例子。在这个例子中,首先实现算子间并行解决训练程序内存开销过大的问题:该训练程序的算子 1 和算子 2 被分摊到了设备 1 和设备 2 上。进一步,通过数据并行添加设备 3 和设备 4,提升系统算力。为了达到这一点,对训练数据进行分区(数据分区 1 和数据分区 2),并将模型(算子 1 和算子 2,这里不一定是单个算子,可以是对计算图做拆分)分别复制到设备 3 和设备 4。在前向计算的过程中,设备 1 和设备 3 上的算子 1 副本同时开始,计算结果分别发送给设备 2 和设备 4 完成算子 2 副本的计算。在反向计算中,设备 2 和设备 4 同时开始计算梯度,本地梯度通过 AllReduce 操作进行平均。反向计算传递到设备 1 和设备 3 上的算子 1 副本结束。

2.4 流水线并行

还有一种常用的实现分布式训练的方法谁流水线并行,这种系统通过算子内并行和算子间并行解决单设备内存不足的问题。

然而,这类系统的运行中,计算图中的下游设备(Downstream Device)需要长期持续处于空闲状态,等待上游设备(Upstream Device)的计算完成,才可以开始计算,这极大降低了设备的平均使用率。这种现象称为模型并行气泡(Model Parallelism Bubble)。

下图是一个示意图,可以看出来各个设备之间有一定的计算执行次序,但是存在大量的空闲时间:

为了减少气泡,通常可以在训练系统中构建流水线。这种做法是将训练数据中的每一个小批次划分为多个微批次(Micro-Batch)。假设一个小批次有 D 个训练样本,将其划分为 M 个微批次,那么一个微批次就有 D / M 个数据样本。每个微批次依次进入训练系统,完成前向计算和反向计算,计算出梯度。每个微批次对应的梯度将会缓存,等到全部微批次完成,缓存的梯度会被加和,算出平均梯度(等同于整个小批次的梯度),完成模型参数的更新。(感觉流水线并行本质还是混合并行?

本例中,模型参数需要切分给 4 个设备存储。为了充分利用这 4 个设备,将小批次切分为两个微批次。假设 Fi,j 表示第 j 个微批次的第 i 个前向计算任务,Bi, j 表示第 j 个微批次的第 i 个反向计算任务。当设备 1 完成第一个微批次的前向计算后(表示为 F0,0),会将中间结果发送给设备 2,触发相应的前向计算任务(表示为F1,0)。与此同时,设备1也可以开始第二个微批次的前向计算任务(表示为 F0,1)。前向计算会在流水线的最后一个设备,即设备3,完成。

系统于是开始反向计算。设备 4 开始第 1 个微批次的反向计算任务(表示为 B3,0)。该任务完成后的中间结果会被发送给设备 3,触发相应的反向计算任务(表示为 B2,0)。与此同时,设备 4 会缓存对应第 1 个微批次的梯度,接下来开始第 2 个微批次计算(表示为 B3,1)。当设备 4 完成了全部的反向计算后,会将本地缓存的梯度进行相加(这里设备 4 相当于主进程,reduce 的操作由它汇总),并且除以微批次数量,计算出平均梯度,该梯度用于更新模型参数。

需要注意的是,计算梯度往往需要前向计算中产生的激活值。经典模型并行系统中会将激活值缓存在内存中,反向计算时就可以直接使用,避免重复计算。而在流水线训练系统中,由于内存资源紧张,前向计算中的激活值往往不会缓存,而是在反向计算中重新计算(Recomputation),也就是用计算换内存

在使用流水线训练系统中,时常需要调试微批次的大小,从而达到最优的系统性能。当设备完成前向计算后,必须等到全部反向计算开始,在此期间设备会处于空闲状态。

可以看到上图中设备 1 在完成两个前向计算任务后,要等很长时间才能开始两个反向计算任务(等到其他设备前向和反向都计算完了才轮到它计算反向)。这其中的等待时间即被称为流水线气泡(Pipeline Bubble)。

为了减少设备的等待时间,一种常见的做法是尽可能地增加微批次的数量,从而让反向计算尽可能早开始。然而,使用非常小的微批次,可能会造成微批次中的训练样本不足,从而无法充分的利用起来硬件加速器中的海量计算核心。因此最优的微批次数量由多种因素(如流水线深度、微批次大小和加速器计算核心数量等)共同决定。

3. 集合通信

作为并行计算的一个重要概念,集合通信经常被用来构建高性能的单程序流/多数据流(Single Program-Multiple Data, SPMD)程序。

3.1 常见集合通信算子

3.1.1 通信模型

假定在一个分布式机器学习集群中,存在 p 个计算设备,并由一个网络来连接所有的设备。每个设备有自己的独立内存,并且所有设备间的通信都通过该网络传输。同时,每个设备都有一个编号 i,其中 i 的范围从 1 到 p。 设备之间的点对点(Point-to-Point, P2P)通信由全双工传输(Full-Duplex Transmission)实现。该通信模型的基本行为可以定义如下:

  • 每次通信有且仅有一个发送者(Sender)和一个接收者(Receiver)。在某个特定时刻,每个设备仅能至多发送或接收一个消息(Message)。每个设备可以同时发送一个消息和接收一个消息。一个网络中可以同时传输多个来自于不同设备的消息。
  • 传输一个长度为 l 个字节(Byte)的消息会花费 a+b×l 的时间,其中 a 代表延迟(Latency),即一个字节通过网络从一个设备出发到达另一个设备所需的时间;b 代表传输延迟(Transmission Delay),即传输一个具有 l 个字节的消息所需的全部时间。前者取决于两个设备间的物理距离(如跨设备、跨机器、跨集群等),后者取决于通信网络的带宽。需要注意的是,这里简化了传输延迟的定义,其并不考虑在真实网络传输中会出现的丢失的消息(Dropped Message)和损坏的消息(Corrupted Message)的情况。

3.1.2 Broadcast

一个分布式机器学习系统经常需要将一个设备 i 上的模型参数或者配置文件广播(Broadcast)给其余全部设备。

因此,可以把Broadcast算子定义为从编号为 i 的设备发送长度为 l 字节的消息给剩余的 p−1 个设备。下图的左上半部分展示了设备 1(在一个三设备的集群里)调用 Broadcast 的初始和结束状态:

一种简单实现 Broadcast 的算法是在设备 i 上实现一个循环,该循环使用 p−1 次 Send / Receive 操作来将数据传输给相应设备。然而,该算法不能达到并行通信的目的(该算法只有 (a+b×l)×(p−1)

的线性时间复杂度)。为此,可以利用分治思想对上述简单实现的Broadcast算法进行优化。假设所有的设备可以重新对编号进行排列,使得 Broadcast 的发送者为编号为 1 的设备。同时,为了简化计算过程,假设对某个自然数 n,p=2^n。 现在,可以通过从 1 向 p/2 发送一次信息把问题转换为两个大小为 p/2 的子问题:编号为 1 的设备对编号 1 到编号 p/2−1 的 Broadcast,以及编号为 p/2 的设备对编号 p/2 到编号 p 的 Broadcast。我们便可以通过在这两个子问题上进行递归来完成这个算法,并把临界条件定义为编号为 i 的设备在 [i, i] 这个区间中的 Broadcast。此时,由于 i 本身已经拥有该信息,不需要做任何操作便可直接完成 Broadcast。这个优化后的算法为 (a+b×l)×log⁡p 时间复杂度,因为在算法的每一阶段(编号为 t),有 2^t 个设备在并行运行 Broadcast 算子。同时,算法一定会在 log⁡p 步之内结束。

3.1.3 Reduce

在分布式机器学习系统中,另一个常见的操作是将不同设备上的计算结果进行聚合(Aggregation)。例如,将每个设备计算的本地梯度进行聚合,计算梯度之和(Summation)。这些聚合函数(表达为 f)往往符合结合律(Associative Law)和交换律(Commutative Law)(不然也没法分发到不同的设备上计算)。这些函数由全部设备共同发起,最终聚合结果存在编号为 i 的设备上。常见聚合函数有加和、乘积、最大值和最小值。集合通信将这些函数表达为 Reduce 算子。上图中上半部分展示了设备 1 调用 Reduce 来进行加和的初始和结束状态。

一个简易的 Reduce 的优化实现同样可以用分治思想来实现,即把 1 到 p/2−1 的 Reduce 结果存到编号为 1 的设备中,然后把 p/2 到 p 的 Reduce 结果存到 p/2 上。最后,可以把 p/2 的结果发送至 1,执行 f,并把最后的结果存至 i。假设 f 的运行时间复杂度为常数并且其输出信息的长度 l 不改变,Reduce的时间复杂度仍然为 (a+b×l)×log⁡p。

3.1.4 AllReduce

集合通信通过引入 AllReduce 算子,从而将 Reduce 函数 f 的结果存至所有设备上。上图右上半部分

展示了设备 1,设备 2 和设备 3 共同调用 AllReduce 来进行加和的初始和结束状态。

一种简单的 AllReduce 实现方法是首先调用 Reduce 算法并将聚合结果存到编号为 1

的设备上。然后,再调用 Broadcast 算子将聚合结果广播到所有的设备。这种简单的 AllReduce 实现的时间复杂度为 (a+b×l)×log⁡p。

3.1.5 Gather

Gather算子可以将全部设备的数据全部收集(Gather)到编号为 i 的设备上(Gather 就是收集的意思)。上图左下部分展示了设备 1 调用 Gather 来收集全部设备的数据的初始和结束状态。

在收集函数(Gather Function)符合结合律和交换律的情况下,可以通过将其设为 Reduce 算子中的

f 来实现 Gather 算子。但是,在这种情况下,无论是基于链表还是数组的实现,在每一步的 Reduce 操作中 f 的时间复杂度和输出长度 l 都发生了改变。因此,Gather 的时间复杂度是 a×log⁡p+(p−1)×b×l。这是因为在算法的每一阶段 t,传输的信息长度为 2^t×l。

3.1.6 AllGather

AllGather 算子会把收集的结果分发到全部的设备上。上图中下部分展示了设备 1,设备 2 和设备 3 共同调用 AllGather 的初始和结束状态。

在这里,一个简单的方法是使用 Gather 和 Broadcast 算子把聚合结果先存到编号为 1 的设备中,再将其广播到剩余的设备上。这会产生一个 a×log⁡p+(p−1)×b×l+(a+p×l×b)×log⁡p 的时间复杂度,因为在广播时,如果忽略链表/数组实现所带来的额外空间开销,每次通信的长度为 pl 而不是 l。简化后,得到了一个 a×log⁡p+p×l×b×log⁡p 的时间复杂度。在一个基于**超立方体**的算法下,可以将其进一步优化到和 Gather 算子一样的时间复杂度 a×log⁡p+(p−1)×b×l。

3.1.7 Scatter

Scatter 算子可以被视作 Gather 算子的逆运算:把一个存在于编号为 i 的设备上,长度为 p(信息长度为 p×l)的链式数据结构 L 中的值分散到每个设备上,使得编号为 i 的设备会得到 L[i] 的结果。 上图右下部分展示了设备 1 调用 Scatter 的初始和结束状态。

可以通过模仿 Gather 算法设计一个简易的 Scatter 实现:每一步的运算中,我们把现在的子链继续对半切分,并把前半段和后半段作为子问题进行递归。这时候,在算法的每一阶段 t,传输的信息长度为 l×2^(m−t),其中 m 是算法总共运行的步骤,不会超过log⁡p (见Broadcast算子的介绍)。最终,Scatter 算子的简易实现和 Gather 算子一样都有 a×log⁡p+(p−1)×b×l 的时间复杂度。

3.2 基于 AllReduce 的梯度平均算法

利用 AllReduce 算子可以实现大型集群中的高效梯度平均。首先,可以考虑一种简单的计算平均梯度的方法:在集群中分配一个设备收集本地梯度,并在计算平均梯度后再将其广播到全部的设备。

这种做法易于实现,但是引入了两个问题。

  • 首先,多台设备同时给该聚合设备发送数据时,聚合设备会因严重的带宽不足产生网络拥塞。
  • 其次,单台设备需要负担大量的梯度平均计算,而受限于单台设备上的有限算力,这种计算往往会受限于算力瓶颈。

为了解决上述问题,可以引入 AllReduce 算子的 Reduce-Broadcast 实现来优化算法,其设计思路是:**通过让全部的节点参与到梯度的网络通信和平均计算中,将巨大的网络和算力开销均摊给全部节点。**这种做法可以解决先前单个梯度聚合节点的问题。假设有 M 个设备,每个设备存有一个模型副本,该模型由 N 个参数/梯度构成。那么按照 AllReduce 算子的要求,需要先将全部的参数按照设备数量切分成 M 个分区(Partition),使得每个分区具有 N/M 个参数(其实不一定要平均切分,这里只是为了讨论方便简化了)。

首先给出这个算法的初始和结束状态。如下图所示。该例子含有 3 个设备。在每个设备有一个模型副本的情况下,这个副本有 3 个参数。那么按照 AllReduce 的分区方法,参数会被划分成 3 个分区(3 个设备),而每一个分区则有 1 个参数(N/M,N 代表 3 个参数,M代表 3 个设备)。在这个例子中,假定设备 1 拥有参数 2, 4, 6,设备 2 拥有参数 1, 2, 3,设备 3 拥有参数 4, 8, 12,那么在使用一个 AllReduce 算子进行计算过后,全部的设备都将拥有梯度相加后的结果 7, 14, 21,其中分区 1的结果 7 是由 3 个设备中分区 1 的初始结果相加而成(7 = 1 + 2 + 4)。为了计算平均梯度,每个设备只需要在最后将梯度之和除以设备数量即可(分区 1 的最终结果为 7 除以 3)。

AllReduce 算子会把梯度的计算拆分成 M−1 个 Reduce 算子和 M−1 个 Broadcast 算子(其中 M

是节点的数量)。其中,Reduce 算子用于计算出梯度的加和,Broadcast 算子用于把梯度之和广播给全部的节点。上图展示了一个 AllReduce 算子的执行过程。AllReduce 算子由 Reduce 算子开始,在第一个 Reduce 算子中,AllReduce 算子会对全部节点进行配对(Pairing),让它们共同完成梯度相加的操作。在上图的第一个 Reduce 算子中,设备 1 和设备 2 进行了配对共同对分区 1 的数据相加。其中,设备 2 把本地的梯度数据 1 发送给设备 1,设备 1 将接收到的梯度数据 1 和本地的分区 1内的梯度数据 2 进行相加,计算出中间梯度相加的结果 3。与此同时,设备 1 和设备 3 进行配对,共同完成对分区 3 的数据相加。而设备 3 和设备 2 进行配对,共同完成对于分区 2 的数据相加。接下来的 Broadcast 过程类似。

3.3 利用集合通信优化模型训练实践

3.3.1 ZeRO

ZeRO是微软提出的神经网络优化器。训练大模型的时候容易发现需要的显存远超单卡的显存,因此需要把模型切分之后放到不同的计算节点上(数据并行)。Zero 有三个主要的关于集合通信的优化技术:

  • 单一节点上的参数存储:现代集群中节点内部加速器的带宽远大于节点之间的带宽。为此,需要尽量减少节点间的通信,并且保证大部分通信仅存在于节点内部的加速器之间。在观察模型切片时,又可得模型本身前向和反向计算时需要在不同切片之间进行的通信远小于不同模型副本梯度平均的通信量。针对这一特性,ZeRO 选择了将单一模型的全部切片存储到同一节点内部,从而大大提高了训练效率(也就是主要做数据并行)。
  • 基于 AllGather 的前向计算: 假设模型中的参数在层级上呈线性,便可按照参数在网络上的顺序从前到后将其分别存储到不同加速器中。在前向时,可以注意到某一层的计算仅依赖于其相邻层的参数。对此,可以对所有包含模型参数的加速器进行一次 AllGather 计算,用来提取每一层的后一层的参数,以及计算该层本身的激活值。为了节约内存,在 AllGather 操作结束后需要立即丢弃除了该层以外其他层的参数。
  • 基于 ReduceScatter 的梯度平均: 在反向计算时我们只需要前一层的参数来计算本层的激活值和梯度,因此只需要再次使用 AllGather 来完成每个加速器上的梯度计算。同时,在聚集梯度后,对于每个加速器仅需要和加速器的编号相同的层数对应的梯度。对此,可以使用 ReduceScatter 算子直接把相应的梯度存到编号为 i的加速器上,而不是通常情况下使用 AllReduce 算子。

3.3.2 DALL-E

DALL-E 是 OpenAI 提出的一个基于文字的图片生成模型,模型拥有高达 120 亿的参数。在训练时,除了运用到 ZeRO 所使用的 AllGather + ReduceScatter 技巧,OpenAI 团队在其他细节上做了进一步的优化。这里,介绍两个主要的关于集合通信的优化技术:

  • 矩阵分解: 集合通信算子的运行速度和信息本身的长度正相关。在模型训练中,这代表了模型参数本身的大小。对此,DALL-E 选择用矩阵分解(Matrix Factorization)的方法先把高维张量调整为一个二维矩阵,通过分解后分开用集合通信算子进行传输,从而大大减少了通信量。
  • 自定义数据类型: 一种减少通信量的方法在于修改数据类型本身。显然地,可以使用 16 位的半精度浮点数,相比正常的 32 位参数表示可以节省近一倍的通信量。但是,在实践中发现低精度的数据类型会使得模型收敛不稳定,导致最终训练效果大打折扣。为此,OpenAI 分析了 DALL–E 的模型结构,并把其中的参数根据对数据类型精度的敏感性分为了三类。其中对精度最敏感的一类照常使用 32 位浮点表示并只通过 AllReduce 算子来同步,而最不敏感的参数则照常通过矩阵分解进行压缩和传输。也就是把对模型效果影响较大的参数用高精度存储,影响较小的参数用低精度存储。

4. 参数服务器

4.1 系统架构

不同于基于集合通信实现的机器学习系统,参数服务器系统中的服务器会被分配两种角色:训练服务器和参数服务器。 也就是有一类服务器专门用来更新和分发参数,而不做实际的训练计算。参数服务器需要提供充足内存资源和通信资源,训练服务器需要提供大量的计算资源(如硬件加速器)。

下图是一个带有参数服务器的机器学习集群。这个集群中含有两个训练服务器和两个参数服务器。 假设我们有一个模型,可以切分为两个参数分区。每个分区被分配给一个参数服务器负责参数同步。 在训练的过程中,每个训练服务器都会有完整的模型,根据本地的训练数据集切片(Dataset Shard)训练出梯度。这个梯度会被推送(Push)到各自的参数服务器。参数服务器等到两个训练服务器都完成梯度推送,开始计算平均梯度,更新参数。它们然后通知训练服务器来拉取(Pull)最新的参数,开始下一轮训练迭代。

4.2 异步训练

参数服务器的一个核心作用是可以处理分布式训练服务器中出现的落后者(Straggler)。在之前的讨论中,在每一轮训练结束后,训练服务器都需要计算平均梯度对每一个模型副本进行更新,从而保证下一轮训练开始前,全部模型副本参数的一致性,这种对于参数一致性的确保一般被称为同步训练(Synchronous Training)。同步训练一般有助于训练系统达到更好的模型精度,但是当系统规模变大,往往会观察到落后者服务器的出现。落后者出现的原因很多。常见的原因包括:落后者设备可能和其他设备不在同一个机柜中,因此落后者的通信带宽显著小于其他设备。另外,落后者设备也可能和其他进程共享本地的服务器计算和通信资源,形成资源竞争,从而降低了性能。

落后者对于基于 AllReduce 的同步训练系统的性能有显著影响,这是因为 AllReduce 让全部节点参与到平均梯度的计算和通信中,而每个节点负责等量的数据。因此一个落后者的出现,都会让整个 AllReduce 操作延迟完成。为了解决这个问题,人们常使用参数服务器同步梯度。一种常见的设计是:训练服务器训练出梯度后,会把本地梯度全部推送到参数服务器。参数服务器在等到一定训练服务器(例如 90% 的训练服务器)的梯度后,就开始计算平均梯度。这样可以确保平均梯度的计算不会被落后者的出现延误。计算好的平均梯度马上推送给全部训练服务器,开始下一轮训练。

**解决落后者的另一种常见做法是利用参数服务器实现异步训练(Asynchronous Training)。**在一个异步训练系统中,每个训练服务器在训练开始时,有相同的模型参数副本。在训练中,它们计算出梯度后会马上将梯度推送到参数服务器,参数服务器将推送的梯度立刻用于更新参数,并通知训练服务器立刻来拉取最新的参数。**在这个过程中,不同的训练服务器很可能会使用不同版本的模型参数进行梯度计算,这种做法可能会伤害模型的精度,但它同时让不同训练服务器可以按照各自的运算速度推送和拉取参数,而无须等待同伴,因此避免了落后者对于整个集群性能的影响。**也就是用精度换速度。

5. 参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/空白诗007/article/detail/1011857
推荐阅读
相关标签
  

闽ICP备14008679号