赞
踩
全切片数据并行(Fully Sharded Data Parallel,简称为FSDP)是数据并行的一种新的方式,FSDP最早是在2021年在FairScale-FSDP中提出的,后来合入了PyTorch 1.11版本中。微软之前Deepspeed框架中提出过三种级别的ZERO算法,FSDP可以看成是ZERO-3
的实现。
传统的数据并行(DDP)是在每一个GPU卡上保存整个model的参数/梯度/优化器状态, 然后对数据集切分为
N
N
N 个shard分片给不同的GPU进行训练,计算完梯度后通过all-reduce通信来做梯度的融合。如下图:
在FSDP中的主要思路是想办法把model的梯度/优化器状态/参数都进行切分操作,每个GPU只存部分的参数信息,也就是在ZERO-3
的思路。为了能把所有的参数进行分片处理,核心在于要把DDP中的all-reduce操作拆解为reduce-scatter和all-gather 操作。
如下图,在进行FSDP前向计算其中的一层Layer时,由于每个GPU都只保存了部分参数,所以需要先通过all-gather操作获得全部的参数;同理,在反向计算过程中,也需要通过all-gather操作,获得全部的参数;最后计算出来的梯度只是部分的结果,需要通过reduce-scatter通信进行累加操作,最终每个GPU卡分别只更新自己那部分参数(也就是local本地weight更新)。
FSDP的应用是对原有model layers加上了一层wrapper封装,只有在FSDP实例中的layer才会在前向和后向过程中执行gather相关操作,通过切分可以利用相同的显存大小训练更大的模型。为了进一步提升显存利用率,FSDP也支持把不活跃的实例全部offload调出到CPU上去。
FSDP计算过程的伪码如下:
FSDP forward pass:
for layer_i in layers:
all-gather full weights for layer_i
forward pass for layer_i
discard full weights for layer_i
FSDP backward pass:
for layer_i in layers:
all-gather full weights for layer_i
backward pass for layer_i
discard full weights for layer_i
reduce-scatter gradients for layer_i
在PyTorch中的示例如下, 通过FullyShardedDataParallel
实现对model的封装,通过CPUOffload
来决定采用哪种策略把参数调到CPU上。
from torch.distributed.fsdp import ( FullyShardedDataParallel, CPUOffload, ) from torch.distributed.fsdp.wrap import ( default_auto_wrap_policy, ) import torch.nn as nn class model(nn.Module): def __init__(self): super().__init__() self.layer1 = nn.Linear(8, 4) self.layer2 = nn.Linear(4, 16) self.layer3 = nn.Linear(16, 4) model = DistributedDataParallel(model()) fsdp_model = FullyShardedDataParallel( model(), fsdp_auto_wrap_policy=default_auto_wrap_policy, cpu_offload=CPUOffload(offload_params=True), )
使用FSDP训练GPT-175B和GPT-1T参数量大小的模型,词表大小50K,fp16的精度和使用SGD的优化器。
结果如下,使用FSDP时在GPU卡数增大的情况下,对GPU单卡的吞叶没有影响;在A100-40G机器下增大batch_size
但吞吐没有增加, 瓶颈不在于通信而是CUDA cache的分配到了瓶颈;当换为A100-80G机器时,CUDA cache的分配问题得到解决后,增大batch_size
后吞吐进一步增加。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。