当前位置:   article > 正文

Facebook的ZeRO算法原理及简单代码实验(小显卡训大模型)_zero通信量分析

zero通信量分析

1、MP效果已经很好了,为什么还要ZeRO?

模型并行(MP:https://zhuanlan.zhihu.com/p/366906920): 将模型横向或垂直分割,将计算和参数划分到每一层,跨多个设备,需要每一层之间的重要通信。在 GPU 之间通信带宽高的单个节点内工作良好,但跨节点工作会较慢。
ZeRO:同样是跨多个设备,将模型和参数放到不同设备,但是通信量却大大减少。

2、那ZeRO具体是怎么优化的呢?

2.1 优化器回顾

SGD: 没有动量概念
在这里插入图片描述

Adam优化器:Adam在SGD基础上,为每个参数梯度增加了一阶动量(momentum)和二阶动量(variance)

在这里插入图片描述

2.2 显存去哪了

  1. 模型状态(model states): 模型参数(fp16)、模型梯度(fp16)和Adam状态(fp32的模型参数备份,fp32的momentum和fp32的variance)。
  2. 剩余状态(residual states): 除了模型状态之外的显存占用,包括激活值(activation)、各种临时缓冲区(buffer)以及无法使用的显存碎片(fragmentation)。
    问题一:
    GPT-2(2B)在混合精度训练的情况下,放到显卡需要的内存是多少?

2.3 ZeRO原理

针对模型状态的存储优化,ZeRO使用的方法是分片,即每张卡只存 1/N 的模型状态量,这样系统内只维护一份模型状态。
ZeRO-1:optimizer分片
ZeRO-2:optimizer + Gradient分片
ZeRO-3:optimizer + Gradient + model分片
在这里插入图片描述

VS DDP

在这里插入图片描述

1.3.1 ZeRO-1

解决的问题:Optimizer state的冗余。
没有即没有将模型本身进行分,也没有将Gradient进行分片,而是只将优化器进行分片。
过程动画:https://zhuanlan.zhihu.com/p/394064174 (原文链接)
动画链接(可能会过期):https://vdn6.vzuu.com/SD/bd03b0bc-ef95-11eb-8ee1-ce96bf022449.mp4?pkey=AAWvE2ChU9kHMLO_n5M8CJeSDHqfRZRy2dMrPU4eOnfzHOWOaGYoGlxJIEAzhzJT4Fgsk-wW1oESc3ngsFZcCFRM&c=avc.0.0&f=mp4&pu=078babd7&bu=078babd7&expiration=1660109505&v=ks6
在这里插入图片描述

  • 训练过程与DDP类似。forward过程由每个rank的GPU独自完整的完成,然后进行backward过程。在backward过程中,梯度通过allReduce进行同步。
  • Optimizer state 使用贪心策略基于参数量进行分片,以此确保每个rank几乎拥有相同大小的优化器内存。
  • 每个rank只负责更新当前优化器分片的部分,由于每个rank只有分片的优化器state,所以当前rank忽略其余的state。
  • 在更新过后,通过广播或者allGather的方式确保所有的rank都收到最新更新过后的模型参数。
  • ZeRO-1 非常适合使用类似Adam进行优化的模型训练,因为Adam拥有额外的参数m(momentum)与v(variance),特别是FP16混合精度训练。
  • ZeRO-1 不适合使用SGD类似的优化器进行模型训练,因为SGD只有较少的参数内存,并且由于需要更新模型参数,导致额外的通讯成本。
1.3.2 ZeRO-2

解决的问题:gradient 的冗余。
为了减少梯度Gradient冗余以此进一步节省内存,ZeRO-2提出gradient sharding,在FairScale里称之为Sharded Data Parallel(SDP)。相比与ZeRO-1, ZeRO-2除了对optimizer state进行切分,还对Gradient进行了切分。

  • 像ZeRO-1 一样将optimizer的参数进行分片,并安排在不同的rank上。
  • 在backward过程中,gradients被reduce操作到对应的rank上,取代了all-reduce,以此减少了通讯开销。
  • 每个rank独自更新各自负责的参数。
  • 在更新操作之后,广播或allGather保证所有的ranks接受到更新后的参数。
    在这里插入图片描述

1.3.3 ZeRO-3
解决的问题:model参数分割。
为了进一步节省更多的内存,ZeRO-3提出进行模型参数的分片。类似以上两种分片方式,ranks仅负责模型参数的切片。可以进行参数切片的原因主要有以下两点:

  • AllReduce操作可以被拆分为Reduce与allgather操作的结合。
  • 模型的每一层拥有该层的完整参数,并且整个层能够直接被一个GPU装下。所以计算前向的时候,除了当前rank需要的层之外,其余的层的参数可以抛弃。
    过程动画:https://vdn6.vzuu.com/SD/2a0e318c-ef96-11eb-9cd5-6ad0d31fb0b0.mp4?pkey=AAWtWJs2zkZXzKwAdUt7rtD8scwCEPuOBq7Pn7VLpGstNISogXgsust_iUhM7RpuZG1rzqTPelWdmGfW_PpBk0lw&c=avc.0.0&f=mp4&pu=078babd7&bu=078babd7&expiration=1660046771&v=ks6
    在这里插入图片描述

1.4 ZeRO通信分析

  • DDP:reduce-scatter + all-gather,分别需要Ψ的通信量,每gpu共计消耗2Ψ通信量。
  • Pos、Pos+g:reduce-scatter + all-gather,通信量每gpu共计消耗2Ψ。
  • Pos+g+p:需要对梯度进行一次reduce-scatter操作(因为每个gpu各自负责部分参数的更新,因此不需要对梯度进行all-gather操作),对参数需要进行正向和反向两次传递,所以需要消耗2Ψ通信量,共计每gpu消耗3Ψ通信量。

问题二:请分析分析MP @方恺齐 的通信量?

1.5 FSDP源码分析

https://github.com/pytorch/pytorch/blob/b91ff5e361623685799b8ef725a91b756685a9ae/torch/distributed/fsdp/fully_sharded_data_parallel.py#L462
代码问题:
1、FSDP对模型参数是怎么进行操作和存储的,和普通的DDP有什么不同?
将参数拉平,并存储了每个参数的size
2、FSDP分别是在哪个函数实现模型参数分割和合并的?
_sharded_parameters
3、对于模型参数不能均分的情况,FSDP采用了什么策略?
pad
https://github.com/pytorch/pytorch/blob/e81664449559f95d0b8d0fe57d66544a0ab84fe8/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3237

1.6 实际应用

1.6.1 实验代码
import argparse
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from fairscale.optim.oss import OSS
from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel as SDP 
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler


scaler = GradScaler()


class MyModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, inner_dim, hidden_dim, num_choices, nlayers=2):
        super().__init__()
        self.nlayers = nlayers
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, hidden_dim)
        self.fn = nn.Sequential(nn.Linear(hidden_dim, inner_dim), 
                                    nn.ReLU(),
                                    nn.Linear(inner_dim, hidden_dim))
        self.drop = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, num_choices)
        
    def forward(self, input_ids):
        embed = self.embed(input_ids)
        v = self.linear(embed)
        v = self.fn(v)
        last_token_hidden = v[:, -1]
        last_token_hidden = self.drop(last_token_hidden)
        logits = self.classifier(last_token_hidden)
        return logits


def initialize_distributed(args):
    """Initialize torch.distributed."""

# Manually set the device ids.
device = args.rank % torch.cuda.device_count()
print(f'rank = {args.rank} || local_rank = {args.local_rank}')
if args.local_rank is not None:
    device = args.local_rank
torch.cuda.set_device(device)
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
print(f'init_method = {init_method}')
dist.init_process_group(backend='nccl',
                        world_size=args.world_size,
                        rank=args.rank,
                        init_method=init_method)
dist.all_reduce(torch.zeros(1).cuda())

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--autocast', action='store_true', help='Run in pytorch autocast mode.')
parser.add_argument('--zero3', action='store_true', help='Run in pytorch autocast mode.')
parser.add_argument('--zero2', action='store_true', help='Run in pytorch autocast mode.')
parser.add_argument('--zero1', action='store_true', help='Run in pytorch autocast mode.')

parser.add_argument('--local_rank',
                    type=int,
                    default=None,
                    help='local rank passed from distributed launcher')

args = parser.parse_args()

args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))

initialize_distributed(args)

total_steps = 10000000
batch_size = 1
vocab_size = 20000
data_len = 512
embed_dim = 10000
inner_dim = 10000
hidden_dim = 20000
num_choices = 10000
loss_fn = nn.CrossEntropyLoss()

device = "cuda:{}".format(torch.cuda.current_device())

model = MyModel(vocab_size, embed_dim, inner_dim, hidden_dim, num_choices)
model.to(device)

n_all_param = sum([p.nelement() for p in model.parameters()])

if args.rank == 0:
    print(f'n_all_param: {n_all_param}')
    for k, v in model.named_parameters():
        print(f'rank: {args.rank} --- {k} shape: {v.shape}')

if args.zero1:
    base_optimizer_arguments = {'lr':0.05}
    base_optimizer = torch.optim.Adam
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)
    model = DDP(model)

elif args.zero2:
    base_optimizer_arguments = {'lr':0.05}
    base_optimizer = torch.optim.Adam 
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)
    model = SDP(model, optimizer)
    if args.autocast:
        from fairscale.optim.grad_scaler import ShardedGradScaler
        scaler = ShardedGradScaler()

elif args.zero3:
    print(f'zero3 ----')
    optimizer = optim.Adam(model.parameters(), lr=0.05)
    model = FSDP(model, mixed_precision=True)

else:
    optimizer = optim.Adam(model.parameters(), lr=0.05)
    model = DDP(model)

if args.rank == 0:
    for k, v in model.named_parameters():
        print(f'rank: {args.rank} --- {k} shape: {v.shape}')

for step in range(total_steps):
    model.zero_grad()
    data = torch.randint(vocab_size, (batch_size, data_len)).to(device) 
    labels = torch.randint(2, [batch_size]).to(device) 
    if args.autocast:
        with autocast():
            logits = model(data) 
        loss = loss_fn(logits, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
    else:
        logits = model(data) 
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

    if args.rank == 0:
        print(f'step: {step} loss: {loss.item()}')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156

运行上述代码脚本: CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=3 python -W ignore -m torch.distributed.launch --nproc_per_node 2 --master_addr 127.0.0.1 --master_port 8927 fairscale_fsdp.py

结果:
在这里插入图片描述

选答题:
GPT-2(1.5B)在ZeRO-1,ZeRO-2,ZeRO-3模式,2张显卡的情况下,每张显卡内存分别是多少?

1.6.2 基于Fairscale库的使用代码:

ZeRO-1

from fairscale.optim.oss import OSS
from torch.nn.parallel import DistributedDataParallel as DDP


base_optimizer_arguments = {'lr':0.05}
base_optimizer = torch.optim.Adam

optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)
model = DDP(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

ZeRO-2

from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as SDP 

base_optimizer_arguments = {'lr':0.05}
base_optimizer = torch.optim.Adam 
optimizer = OSS(
    params=model.parameters(),
    optim=base_optimizer,
    **base_optimizer_arguments)
model = SDP(model, optimizer)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

ZeRO-3

from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

optimizer = optim.Adam(model.parameters(), lr=0.05)
model = FSDP(model, mixed_precision=True)
  • 1
  • 2
  • 3
  • 4
  • 5

1.6.3 补充部分

下面实验是基于上面的1.6.1实验代码进行的
将整个模型包在一个FSDP

model = FSDP(model)
  • 1

在这里插入图片描述

将每个参数分别包一个FSDP

model.embed = FSDP(model.embed)
model.linear = FSDP(model.linear)
model.fn = FSDP(model.fn)
model.classifier = FSDP(model.classifier)
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

选答题:怎么将模型参数恢复呢?

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/345789
推荐阅读
相关标签
  

闽ICP备14008679号