当前位置:   article > 正文

简明Pytorch分布式训练 — DistributedDataParallel 实践

简明Pytorch分布式训练 — DistributedDataParallel 实践

上一次的Pytorch单机多卡训练主要介绍了Pytorch里分布式训练的基本原理,DP和DDP的大致过程,以及二者的区别,并分别写了一个小样作为参考。小样毕竟还是忽略了很多细节和工程实践时的一些处理方式的。实践出真知,今天(简单)写一个实际可用的 DDP 训练中样,检验一下自己对 DDP 的了解程度,也作为以后分布式训练的模板,逐渐完善。

DDP 流程回顾

Pytorch单机多卡训练中主要从DDP的原理层面进行了介绍,这次从实践的角度来看看。

试问把大象装进冰箱需要几步?三步,先打开冰箱,再把大象装进去,最后关上冰箱门。那实践中DDP又分为几步呢?

  1. 开始训练/预测前的准备工作。
  2. 开始训练/验证!

看,DDP比把大象装进冰箱简单多了!

一些前置

虽然DDP很简单,但是得先了解一些前置知识。有一些DDP相关的函数是必须要了解的1,主要是和分布式通信相关的函数:

  • torch.distributed.init_process_group2。初始化分布式进程组。由于是单机多卡看的分布式,因此其实是将任务分布在多个进程中。
  • torch.distributed.barrier()3。用于进程同步。每个进程进入这个函数后都会被阻塞,当所有进程都进入这个函数后,阻塞解除,继续执行后续的代码。
  • torch.distributed.all_gather4。收集不同进程中的tensor。这个用到的比较多的,但有一点不太好理解。解释一下:某些变量是不同进程中都会有的,比如loss,如果要把各个进程中计算得到的loss汇总到一起,就需要进程间通信,把loss收集起来,all_gather干的就是收集不同进程中的某个tensor的事儿。
  • local_rank。节点上device的标识。在单机多卡的模式下,每个device的local_rank是唯一的,一般我们都在0号设备上操作。该参数不需要手动传参,在执行DDP时会自动设置该参数。当为非DDP时,该参数值为-1。
  • torch.nn.parallel.DistributedDataParallel5。在初始化好的分布式环境中,创建分布式的模型,负责。

好了,你已经具备DDP实践的主要前置知识了,一起开启DDP实践吧!

DDP 实践

训练前

训练前的准备工作,主要包含:

  • 初始化分布式进程组,使后续各个进程间能够通信。
  • 准备好数据和模型。
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600))
args.device = device

X, Y = make_classification(n_samples=25000, n_features=N_DIM, n_classes=2, random_state=args.seed)  # 每个进程都会拥有这样一份数据集
X = torch.tensor(X).float()
Y = torch.tensor(Y).float()

train_X, train_Y = X[:20000], Y[:20000]
val_X, val_Y = X[20000:], Y[20000:]
train_dataset = SimpleDataset(train_X, train_Y)
val_dataset = SimpleDataset(val_X, val_Y)

model = SimpleModel(N_DIM)
if args.local_rank not in (-1, 0):
    torch.distributed.barrier()
else:
    # 只需要在 0 号设备上加载预训练好的参数即可,后续调用DDP时会自动复制到其他卡上
    if args.ckpt_path is not None:
        model.load_state_dict(torch.load(args.ckpt_path))
    torch.distributed.barrier()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

训练时

与单卡训练不同,DDP训练主要有三个改动地方:

  • 创建数据加载器时,需要创建特定的采样器,以保证每个进程使用的数据是不重复的
  • 创建分布式的模型。
  • 收集一些辅助信息时,例如loss、prediction等,需要all_gather将不同进程中的数据收集到一起。
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)

train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)

...

def gather_tensors(target):
    """
    target should be a tensor
    """
    target_list = [torch.ones_like(target) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(target_list, target)
    ret = torch.hstack(target_list)
    
    return ret

for epoch in train_iterator:
    if args.local_rank != -1:
        train_dataloader.sampler.set_epoch(epoch)  # 必须设置的!确保每一轮的随机种子是固定的!

    model.zero_grad()
    model.train()

    pbar = tqdm(train_dataloader, desc="Training", disable=args.local_rank not in [-1, 0])
    step_loss, step_label, step_pred = [], [], []
    for step, batch in enumerate(pbar):
        data, label = batch
        data, label = data.to(args.device), label.to(args.device)

        prediction = model(data)
        loss = loss_func(prediction, label.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1

        step_loss.append(loss.item())
        step_label.extend(label.squeeze().cpu().tolist())
        step_pred.extend(prediction.squeeze().detach().cpu().tolist())

        if step % args.logging_step == 0:
            if torch.distributed.is_initialized():  # 兼容非 DDP 的模式
                step_loss = torch.tensor(step_loss).to(args.device)
                step_label = torch.tensor(np.array(step_label)).to(args.device)
                step_pred = torch.tensor(np.array(step_pred)).to(args.device)
                step_loss = gather_tensors(step_loss)
                step_label = gather_tensors(step_label).squeeze().cpu().numpy()
                step_pred = gather_tensors(step_pred).squeeze().cpu().numpy()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

单卡训练时,收集一些辅助信息是很方便的,因为所有的数据都在同一进程中,所以不需要同步操作。在DDP模式下,每个进程使用不同的数据训练,为了记录训练过程的信息以及计算指标等情况,需要把所有数据的结果收集起来,这就需要用到all_gather了。

启动脚本

一种常用的DDP启动方式:

#! /usr/bin/env bash
node_num=4

# DDP 模式
CUDA_VISIBLE_DEVICES="0,1,2,3" python -m torch.distributed.launch --nproc_per_node=$node_num --master_addr 127.0.0.2 --master_port 29502 ddp-demo.py \
--do_train \
--do_test \
--epochs 3 \
--batch_size 128 \
--lr 5e-6 \
--output_dir ./test_model \
--seed 2024
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

python -m torch.distributed.launch即把torch.distributed.launch这个模块作为一个脚本运行,后面接的 --nproc_per_node=$node_num --master_addr 127.0.0.2 --master_port 29502 ddp-demo.py都作为这个模块的参数,关于该模块的参数以及含义,可以参考:https://github.com/pytorch/pytorch/blob/main/torch/distributed/launch.py。

更上一层楼

果然,DDP训练比把大象装进冰箱要简单多了。

当然,这主要归功于Pytorch封装的好,还有很多细节是值得我们深究的。以下是一些我觉得可以进一步探索的:

  • 分布式训练时不同的后端之间的区别,比如gloompinccl。主要是加深对多进程通信的了解。
  • torch.nn.parallel.DistributedDataParallel内部的细节,比如梯度的同步过程及涉及到的算法等。
  • DistributedSampler是如何实现不同device读取不同的数据的。这一块其实原理很简单,源码也不多。
  • 多机多卡的实践6

结语

单卡或DP训练相比,DDP启动多个进程,每个进程读取对应的数据,每个step单独训练,能够大大提高显卡的利用率和降低训练时间。

初次接触DDP还是很容易让人茫然的,如果有过多进程的经验的会更好理解其中的一些概念以及分布式训练的流程。

目前写的这个demo虽然比较简单,但常用的内容都涉及到了,后续可以在这个版本上继续完善。

附完整代码:

import os
import time
import random
import logging
import argparse
import datetime
from tqdm import tqdm, trange

import numpy as np
from sklearn import metrics
from sklearn.datasets import make_classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


N_DIM = 200
logging.basicConfig(format='[%(asctime)s] %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m-%d %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger("DDP")

class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

class SimpleDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

def set_seed(args):
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
    
def parse_args():
    parser = argparse.ArgumentParser()
    ## DDP:从外部得到local_rank参数。从外面得到local_rank参数,在调用DDP的时候,其会自动给出这个参数
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--no_cuda", action="store_true", help="Whether to cuda.")
    parser.add_argument("--seed", default=2024, type=int)
    
    parser.add_argument("--ckpt_path", default="./ddp-outputs/checkpoint-100/model.pt", type=str)
    parser.add_argument("--output_dir", default="./ddp-outputs/", type=str)
    
    parser.add_argument("--epochs", default=3, type=int)
    parser.add_argument("--lr", default=1e-5, type=float)
    parser.add_argument("--batch_size", default=2, type=int)
    parser.add_argument("--val_step", default=50, type=int)
    parser.add_argument("--save_step", default=10, type=int)
    parser.add_argument("--logging_step", default=2, type=int)
    
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
    
    args = parser.parse_args()
    return args


def gather_tensors(target):
    """
    target should be a tensor
    """
    target_list = [torch.ones_like(target) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(target_list, target)
    # ret = torch.cat(target_list, 0)
    ret = torch.hstack(target_list)
    
    return ret


def train(args, model, train_dataset, val_dataset=None):
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    
    model = model.to(args.device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)
    
    logger.info(f"After DDP, Rank = {args.local_rank}, weight = {model.module.fc.weight.data[0, :10]}")
    
    # 采样器
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)
    
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    loss_func = nn.BCELoss().to(args.local_rank)
    
    total_steps = len(train_dataloader) * args.epochs
    global_step = 0
    train_iterator = trange(0, int(args.epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    for epoch in train_iterator: #range(args.epochs):  # train_iterator:
        if args.local_rank != -1:
            train_dataloader.sampler.set_epoch(epoch)
        
        if args.local_rank in (-1, 0):
            logger.info('*' * 25 + f" Epoch {epoch + 1} / {args.epochs} " + '*'* 25)
        
        model.zero_grad()
        model.train()
        
        pbar = tqdm(train_dataloader, desc="Training", disable=args.local_rank not in [-1, 0])
        step_loss, step_label, step_pred = [], [], []
        for step, batch in enumerate(pbar):
            data, label = batch
            data, label = data.to(args.device), label.to(args.device)
            
            prediction = model(data)
            loss = loss_func(prediction, label.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            time.sleep(0.1)
            
            step_loss.append(loss.item())
            step_label.extend(label.squeeze().cpu().tolist())
            step_pred.extend(prediction.squeeze().detach().cpu().tolist())
            
            if step % args.logging_step == 0:
                if torch.distributed.is_initialized():  # 兼容非 DDP 的模式
                    step_loss = torch.tensor(step_loss).to(args.device)
                    step_label = torch.tensor(np.array(step_label)).to(args.device)
                    step_pred = torch.tensor(np.array(step_pred)).to(args.device)
                    step_loss = gather_tensors(step_loss)
                    step_label = gather_tensors(step_label).squeeze().cpu().numpy()
                    step_pred = gather_tensors(step_pred).squeeze().cpu().numpy()
                
                if args.local_rank in (-1, 0):
                    # logger.info(f"Gathered loss = {step_loss}")
                    # logger.info(f"Gathered label = {step_label}")
                    # logger.info(f"Label shape = {step_label.shape}, Pred = {step_pred}, {prediction.shape}")
                    if not all(step_label == 1) and not all(step_label == 0):
                        auc = metrics.roc_auc_score(step_label, step_pred)
                        # logger.info(f"Step AUC = {auc:.5f}")
                        pbar.set_description(f"loss={step_loss.mean():>.5f}, auc={auc:.5f}")
                
                step_loss, step_label, step_pred = [], [], []
            
            if step % args.val_step == 0 and val_dataset is not None:
                test_val(args, model, val_dataset)
            
            if global_step % args.save_step == 0 and global_step > 0 and args.local_rank in [-1, 0]:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(model_to_save.state_dict(), os.path.join(save_path, "model.pt"))
                torch.save(optimizer.state_dict(), os.path.join(save_path, "optimizer.pt"))
            
                
def test_val(args, model, dataset):
    # 采样器
    test_sampler = RandomSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    test_dataloader = DataLoader(dataset, sampler=test_sampler, batch_size=args.batch_size)
    
    ground_truth = []
    prediction = []
    pbar = tqdm(test_dataloader, desc="Testing", disable=args.local_rank not in [-1, 0])
    with torch.no_grad():
        for step, batch in enumerate(pbar):
            data, label = batch
            data, label = data.to(args.device), label.to(args.device)
            pred = model(data)
            time.sleep(0.05)

            ground_truth.extend(label.cpu().squeeze().tolist())
            prediction.extend(pred.cpu().squeeze().tolist())

        ground_truth = torch.tensor(ground_truth).to(args.device)
        prediction = torch.tensor(prediction).to(args.device)
    
        if args.local_rank != -1:  # 等待所有进程预测完,再进行后续的操作
            torch.distributed.barrier()

        if torch.distributed.is_initialized():  # 由于每张卡上都会有 prediction 和 ground_truth,为了汇集所有卡上的信息,需要 torch.distributed.all_gather
            ground_truth = gather_tensors(ground_truth)
            prediction = gather_tensors(prediction)
    
        if args.local_rank in [-1, 0]:
            print(f"GT : {ground_truth.shape}\tPred : {prediction.shape}")
            auc = metrics.roc_auc_score(ground_truth.detach().cpu(), prediction.detach().cpu())
            logger.info(f"Test Auc: {auc:.5f}")
        
     
def main(args):
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600))
        args.n_gpu = 1
    args.device = device
    
    if args.local_rank in (-1, 0):
        logger.info('*' * 66)
        for k, v in vars(args).items():
            print(f"########## {k:>20}:\t{v}")
        logger.info('*' * 66)
    
    set_seed(args)
    
    model = SimpleModel(N_DIM)
    if args.local_rank not in (-1, 0):
        logger.info(f"Barrier in {args.local_rank}")
        torch.distributed.barrier()  # barrier 用于进程同步。每个进程进入这个函数后都会被阻塞,当所有进程都进入这个函数后,阻塞解除。
    else:
        if args.ckpt_path is not None:
            logger.info(f"Load pretrained model in {args.local_rank}")
            model.load_state_dict(torch.load(args.ckpt_path))
        logger.info(f"Barrier in {args.local_rank}")
        torch.distributed.barrier()
    logger.info(f"After barrier in {args.local_rank}")
    
    logger.info(f"Before DDP, Rank = {args.local_rank}, weight = {model.fc.weight.data[0, :10]}")
    
    X, Y = make_classification(n_samples=25000, n_features=N_DIM, n_classes=2, random_state=args.seed)  # 每个进程都会拥有这样一份数据集
    X = torch.tensor(X).float()
    Y = torch.tensor(Y).float()
    
    # 通过这个可以看到,在不同进程上的 X 的 id 是不同的,但是内容是一致的
    logger.info(f"Rank = {args.local_rank}, id(X) = {id(X)}, Y[:20] = {Y[:20]}")  
    
    train_X, train_Y = X[:20000], Y[:20000]
    val_X, val_Y = X[20000:], Y[20000:]
    train_dataset = SimpleDataset(train_X, train_Y)
    val_dataset = SimpleDataset(val_X, val_Y)
    
    if args.do_train:
        train(args, model, train_dataset, val_dataset)
    
    if args.do_test:
        test_val(args, model, val_dataset)


if __name__ == "__main__":
    args = parse_args()
    main(args)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261

在这里插入图片描述

Reference


  1. DISTRIBUTED COMMUNICATION PACKAGE - TORCH.DISTRIBUTED. ↩︎

  2. 官方文档:torch.distributed.init_process_group. ↩︎

  3. 官方文档:torch.distributed.barrier. ↩︎

  4. 官方文档:torch.distributed.all_gather. ↩︎

  5. 官方文档-torch.nn.parallel.DistributedDataParallel. ↩︎

  6. Multi Node PyTorch Distributed Training Guide For People In A Hurry. ↩︎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/344573
推荐阅读
相关标签
  

闽ICP备14008679号