PyTorch Multi-GPU Parallel Training: DistributedDataParallel

1 Parallelizing model training

1.1 Why train in parallel


1.2 Parallel training strategies


  • Model parallelism


  • Data parallelism



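1.3 Single-node multi-GPU and multi-node multi-GPU

In deep learning and other high-performance computing workloads, single-node multi-GPU and multi-node multi-GPU are two common hardware configurations. Both use multiple graphics processing units (GPUs) to accelerate computation. A single-node multi-GPU setup is usually easier to manage and maintain, while a multi-node multi-GPU setup offers more compute and better scalability at the price of higher complexity and cost.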

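1.3.1 Single-node multi-GPU

  • Definition: all GPUs are installed in the same machine.

  • Communication: GPUs communicate over the PCIe bus or the higher-bandwidth NVLink.

  • Suitability: suited to medium-sized datasets and models; typically used in lab environments or small-scale commercial applications.

  • Setup complexity: relatively low, because all communication stays inside one node.

  • Scalability: limited by the maximum number of GPUs a single node can host.

  • Example scenario: training a deep learning model on a single server in a data center.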

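1.3.2 Multi-node multi-GPU

  • Definition: the GPUs are spread across several machines that are connected by a network.

  • Communication: machines communicate over a high-speed network (for example InfiniBand), which is still slower than communication inside a single node.

  • Suitability: suited to large datasets and models; typically used in large data centers or for complex machine learning tasks.

  • Setup complexity: higher, because network communication and synchronization between nodes have to be managed.

  • Scalability: in theory, capacity can keep growing by adding more nodes.

  • Example scenario: training large deep learning models on servers spread across multiple data centers, for instance large language models or complex scientific computing tasks.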

2 Parallel training with DistributedDataParallel

2.1 Basic concepts







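  • The first approach (starting the processes from inside the script, as with torch.multiprocessing in section 2.3) needs no extra command-line arguments at launch and is fairly easy to write, but it is harder to debug; MAE is an example.
  • The second approach (torch.distributed with a launcher, section 2.2) must be started from the command line and is slightly more complex to write, but it is easier to debug.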

2.2 Distributed training steps with torch.distributed

2.2.1 Import the distributed module


import torch.distributed as dist

2.2.2 Define the script's custom arguments with argparse

import argparse

parser = argparse.ArgumentParser()
# ... your params ...
# ... distributed params ...
# Number of processes to start. It does not need to be set by hand;
# it is derived from nproc_per_node.
parser.add_argument('--world-size', default=4, type=int, help='number of distributed processes')
parser.add_argument('--local_rank', type=int, help='rank of distributed processes')
opt = parser.parse_args()


2.2.3 Initialize distributed


import os

import torch
import torch.distributed as dist


# The function name is assumed; the original listing shows only the body.
def init_distributed_mode(args):
    # Set up the environment of each process
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        return
    # Select the GPU used by this process. Single-GPU training effectively runs in
    # one Python process on one CPU core, and DataParallel still uses one core no
    # matter how many GPUs it drives. In distributed training every GPU gets its own
    # process, so there are as many CPU cores in use as there are GPUs.
    torch.cuda.set_device(args.gpu)
    # Distributed initialization
    args.dist_url = 'env://'    # take the rendezvous settings from environment variables
    args.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()  # wait until every process (every GPU) has reached this point


# Output with 4 processes (the order varies):
# | distributed init (rank 1): env://
# | distributed init (rank 2): env://
# | distributed init (rank 0): env://
# | distributed init (rank 3): env://

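torch.distributed.init_process_group is the PyTorch function that initializes the default distributed process group, which is what allows communication across multiple processes. It is essential when using PyTorch's distributed features, especially for multi-GPU or multi-node training with DistributedDataParallel (DDP). Strictly speaking, dist.init_process_group is the only line that actually performs the distributed initialization.



  • local_rank is assigned automatically; on a single node with multiple GPUs it has the same value as rank.
  • os.environ["RANK"] has no value by default. It is set only when the script is started from the command line with python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py.
  • The option --nproc_per_node=4 sets os.environ["WORLD_SIZE"] to 4 (on a single node).
  • If the script uses argparse, the parser must define a local_rank argument, because the launcher passes it on the command line. If that argument is missing, add --use_env when launching: with --use_env the local rank is passed through the LOCAL_RANK environment variable instead.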
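The last bullet can be made concrete with a small sketch. It assumes a hypothetical helper, parse_local_rank, that accepts the local rank either as a command-line argument or through LOCAL_RANK. Newer launchers (PyTorch 2.0 and later) pass the dashed spelling --local-rank, so the parser below registers both spellings; treat it as an illustration rather than the way any particular project does it.

```python
import argparse
import os


def parse_local_rank(argv, environ=None):
    """Return the local rank from CLI arguments, else from LOCAL_RANK, else 0."""
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    # Accept both spellings: --local_rank (older launchers) and --local-rank (newer ones).
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                        type=int, default=None)
    # parse_known_args leaves the script's other arguments alone.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # With --use_env (or torchrun) the value arrives as an environment variable.
    return int(environ.get("LOCAL_RANK", 0))


print(parse_local_rank(["--local_rank=2"], environ={}))        # 2
print(parse_local_rank(["--local-rank", "3"], environ={}))     # 3
print(parse_local_rank([], environ={"LOCAL_RANK": "1"}))       # 1
```

Reading the environment variable as a fallback means the same script works whether or not --use_env was given.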

2.2.4 Set up distributed dataset loading



# 1. Datasets
train_datasets = MyDataSet(xxx)
val_datasets = MyDataSet(xxx)
# 2. DistributedSampler
# Assigns the sample indices that the process of each rank trains on. For example,
# 800 samples on 8 GPUs gives each GPU 100 samples.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_datasets)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_datasets)
# 3. BatchSampler
# Each GPU received 100 samples. With batch_size=16, 100 / 16 = 6 remainder 4,
# so 4 samples are left over. drop_last=True discards those 4 samples; False keeps
# them as a final batch (which then holds 4 samples instead of 16).
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
# 4. DataLoader
# The validation set does not use a BatchSampler, so batch_size is passed
# to the DataLoader directly.
train_dataloader = torch.utils.data.DataLoader(
    train_datasets, batch_sampler=train_batch_sampler, pin_memory=True, num_workers=nw)
val_dataloader = torch.utils.data.DataLoader(
    val_datasets, batch_size=batch_size, sampler=val_sampler, pin_memory=True, num_workers=nw)
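The arithmetic in the comments above (800 samples, 8 GPUs, batch size 16) can be checked with a few lines of plain Python. This is a simplified stand-in, not PyTorch's sampler: shard_indices and num_batches are hypothetical helpers, shuffling is left out, and the padding mimics the behaviour of repeating indices so that every rank receives the same number of samples.

```python
import math


def shard_indices(num_samples, world_size, rank):
    """Indices one rank would see without shuffling, padded to equal shard sizes."""
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    # Pad by repeating indices from the start so every rank gets the same count.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Rank r takes every world_size-th index starting at r.
    return indices[rank:total:world_size]


def num_batches(samples_per_rank, batch_size, drop_last):
    """Number of batches per rank, with or without the incomplete last batch."""
    if drop_last:
        return samples_per_rank // batch_size
    return math.ceil(samples_per_rank / batch_size)


shard = shard_indices(num_samples=800, world_size=8, rank=3)
print(len(shard), shard[:3])                   # 100 [3, 11, 19]
print(num_batches(100, 16, drop_last=True))    # 6
print(num_batches(100, 16, drop_last=False))   # 7
```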

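2.2.5 Load the model onto all GPUs


model = UNet().cuda()
# device_ids takes the local GPU index; on a single node it equals rank.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
...
for epoch in range(start_epoch, n_epochs):
    if is_distributed:
        # Tell the sampler the epoch so that it shuffles differently each epoch.
        train_sampler.set_epoch(epoch)
    ...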
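Why set_epoch is called at the start of every epoch can be illustrated without PyTorch. In the sketch below, epoch_shard is a hypothetical stand-in for a distributed sampler: every rank seeds its shuffle with seed + epoch, so all ranks agree on one permutation per epoch (keeping the shards disjoint), while a new epoch value produces a new order. If the epoch were never updated, each epoch would reuse the same order.

```python
import random


def epoch_shard(num_samples, world_size, rank, epoch, seed=0):
    """Shuffled shard for one rank; a stand-in for a distributed sampler."""
    order = list(range(num_samples))
    # Every rank seeds with (seed + epoch), so all ranks build the same permutation.
    random.Random(seed + epoch).shuffle(order)
    # Each rank then takes its own strided slice of that shared permutation.
    return order[rank::world_size]


world_size = 4
epoch0 = [epoch_shard(16, world_size, r, epoch=0) for r in range(world_size)]
epoch1 = [epoch_shard(16, world_size, r, epoch=1) for r in range(world_size)]

# Within an epoch the shards are disjoint and together cover the whole dataset.
assert sorted(i for shard in epoch0 for i in shard) == list(range(16))
assert sorted(i for shard in epoch1 for i in shard) == list(range(16))

print("rank 0, epoch 0:", epoch0[0])
print("rank 0, epoch 1:", epoch1[0])
```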

2.2.6 Launch distributed training

python -m torch.distributed.launch --nproc_per_node=4 --master_port=2424 --use_env main.py (your_argparse_params)

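Newer PyTorch versions replace python -m torch.distributed.launch with torchrun. The communication port master_port has to be specified for training. It is also possible to let the program find a port by itself, by replacing --master_port=xxxx with --rdzv_backend c10d --master_port=0.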
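At the operating-system level, "let the program find a port" comes down to binding a socket to port 0 and reading back the port that the OS assigned. The sketch below shows that idea with the standard library only; find_free_port is a hypothetical helper and is not part of PyTorch. The port is only known to be free at the moment of the call, so another process could still claim it before training starts.

```python
import socket


def find_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))             # port 0 means "pick any free port"
        return sock.getsockname()[1]     # the port the OS actually assigned


port = find_free_port()
print(f"--master_port={port}")
```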

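2.3 Distributed training steps with torch.multiprocessing


torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn')

  • fn: the function to run in every process, usually main with the signature main(rank, *args). rank is mandatory as the first parameter and is supplied by spawn; on a single node it can be read as the index of the GPU. The remaining parameters are taken from the args tuple given to spawn.
  • args: the arguments passed to fn, as a tuple.
  • nprocs: the number of processes, i.e. the number of GPUs.
  • join: the default True is fine.
  • daemon: the default False is fine.

# Call
mp.spawn(main, args=(opt, ), nprocs=opt.world_size, join=True)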
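The calling convention described above, where fn receives the rank first and then the contents of args, can be mimicked in a few lines. The sketch uses threads purely for illustration, whereas the real spawn starts separate processes; spawn_like is a hypothetical name and opt is a made-up options dictionary.

```python
import threading


def spawn_like(fn, args=(), nprocs=1):
    """Call fn(rank, *args) once per rank, mimicking the spawn calling convention."""
    workers = [threading.Thread(target=fn, args=(rank, *args)) for rank in range(nprocs)]
    for w in workers:
        w.start()
    for w in workers:   # wait for all workers, similar in spirit to join=True
        w.join()


results = {}


def main(rank, opt):
    # rank is supplied by the launcher; opt comes from the args tuple.
    results[rank] = f"rank {rank} of {opt['world_size']}"


opt = {"world_size": 4}
spawn_like(main, args=(opt,), nprocs=opt["world_size"])
print(sorted(results))   # [0, 1, 2, 3]
```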


# Single-node multi-GPU example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"  # must be set before CUDA is initialized
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    # NOTE: 'tcp://' is incomplete in the source; it has to be followed by the
    # master address and port (host:port) before this will run.
    dist.init_process_group("gloo", init_method='tcp://', rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()
    print("finished rank: {}".format(rank))


def main():
    world_size = torch.cuda.device_count()
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()

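2.4 The dist.barrier() function


if args.local_rank not in [-1, 0]:
    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
# ... (loads the model and the vocabulary)
if args.local_rank == 0:
    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

Suppose we have 4 GPUs [0, 1, 2, 3], where GPU 0 is the first process, also called the base process. Some operations do not need to run on all GPUs at the same time; preprocessing, for example, only needs the base process.

In the code above, the first if means that every process except the main one is held at the barrier when it gets there, so it stops at that point. The base process does not stop: it carries on and performs work such as pre-loading the model. When the main process reaches the second if, it enters the barrier too. It has finished pre-loading, and since every process has now reached a barrier, all of them are released and the other processes continue with the loading step.


In short: a process is blocked by a barrier until all processes have encountered a barrier, upon which the barrier is lifted for all processes.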
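The same pattern can be tried out without any GPUs by using threading.Barrier from the standard library in place of torch.distributed.barrier(). In this sketch the threads play the role of ranks, and load_model and the cache dictionary are made-up stand-ins for "download once, then read from the local cache".

```python
import threading

WORLD_SIZE = 4
barrier = threading.Barrier(WORLD_SIZE)
cache = {}
log = []  # list.append is thread-safe in CPython


def load_model(rank):
    """Download on first use, afterwards read from the cache."""
    if "model" not in cache:
        log.append(("download", rank))
        cache["model"] = "weights"
    else:
        log.append(("cache_hit", rank))
    return cache["model"]


def worker(rank):
    if rank != 0:
        barrier.wait()      # non-zero ranks are held here
    load_model(rank)        # rank 0 gets here first and fills the cache
    if rank == 0:
        barrier.wait()      # rank 0 arrives last, which releases everyone


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(log[0])   # ('download', 0)
```

Only rank 0 ever downloads; the other ranks pass the barrier after the cache has been filled and therefore record a cache hit.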

3 A complete example

3.1 Initialize the process group

import os
from torch import distributed

try:
    world_size = int(os.environ["WORLD_SIZE"])   # total number of processes
    rank = int(os.environ["RANK"])               # index of this process (global)
    local_rank = int(os.environ["LOCAL_RANK"])   # index of this process on its machine (local)
    distributed.init_process_group("nccl")       # initialize the process group with the NCCL backend
except KeyError:
    # Not started by a launcher: fall back to a single process.
    world_size = 1
    rank = 0
    local_rank = 0
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://",  # incomplete in the source; must be followed by host:port
        rank=rank,
        world_size=world_size,
    )
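How the three values read above relate to one another can be written down as simple arithmetic. The sketch assumes that every node runs the same number of processes and that ranks are assigned node by node, which is the usual layout but not guaranteed by every launcher; global_rank and total_world_size are hypothetical helper names.

```python
def global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank of a process, assuming nproc_per_node processes on every node."""
    return node_rank * nproc_per_node + local_rank


def total_world_size(nnodes, nproc_per_node):
    """Total number of processes across all nodes."""
    return nnodes * nproc_per_node


# Two nodes with four GPUs each (hypothetical layout).
for node in range(2):
    ranks = [global_rank(node, 4, local) for local in range(4)]
    print(f"node {node}: local ranks 0-3 -> global ranks {ranks}")
print(total_world_size(2, 4))   # 8
```

On a single node node_rank is 0, so the global rank and the local rank coincide.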

3.2 Split the dataset with DistributedSampler


from torch.utils.data import DataLoader

from dataloader.distributed_sampler import DistributedSampler  # not part of PyTorch

train_sampler = DistributedSampler(
    train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)
# pin_memory=True places the loaded tensors in page-locked (pinned) host memory,
# which makes the copy from host memory to GPU memory faster.
train_loader = DataLoader(
    dataset=train_set,
    pin_memory=True,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=train_sampler,
)

3.3 Wrap the model with DistributedDataParallel

DistributedDataParallel performs an all-reduce on the gradients computed on the different GPUs, that is, it aggregates the gradients from all GPUs and synchronizes the result. After the all-reduce, the gradient of the model on every GPU is the mean of the per-GPU gradients from before the all-reduce.

backbone = get_model(
    cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
    find_unused_parameters=True)
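As a toy illustration of the averaging described above, the following plain-Python sketch simulates an all-reduce over four "GPUs". The gradient values are invented and all_reduce_mean is a hypothetical helper; the real operation works on tensors and runs inside the communication backend.

```python
def all_reduce_mean(per_rank_grads):
    """Average gradients element-wise across ranks; every rank gets the same result."""
    world_size = len(per_rank_grads)
    summed = [sum(values) for values in zip(*per_rank_grads)]
    mean = [s / world_size for s in summed]
    # After the collective, each rank holds an identical copy of the mean.
    return [list(mean) for _ in range(world_size)]


# Made-up gradients of a 3-parameter model on 4 GPUs.
grads = [
    [1.0, 2.0, 3.0],
    [3.0, 2.0, 1.0],
    [0.0, 4.0, 2.0],
    [4.0, 0.0, 2.0],
]
synced = all_reduce_mean(grads)
print(synced[0])   # [2.0, 2.0, 2.0]
```

Because every rank ends up with the same averaged gradient, identical optimizer steps keep the model replicas in sync.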

3.4 Train the model


for epoch in range(start_epoch, cfg.num_epoch):
    if isinstance(train_loader, DataLoader):
        # Pass the epoch to the sampler inside train_loader. DistributedSampler
        # needs it so that all processes use the same random seed in each epoch.
        train_loader.sampler.set_epoch(epoch)
    for _, (img, local_labels) in enumerate(train_loader):
        global_step += 1
        local_embeddings = backbone(img)
        loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
        opt.step()
        opt.zero_grad()
        lr_scheduler.step()

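3.5 Compute the loss


from torch import distributed

# _gather_embeddings and _gather_labels are lists with one tensor per process;
# all_gather fills them with the tensors from every process.
distributed.all_gather(_gather_embeddings, local_embeddings)
distributed.all_gather(_gather_labels, local_labels)
# Sum the loss over all processes (in place).
distributed.all_reduce(loss, distributed.ReduceOp.SUM)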
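What the two collectives do can be shown with a plain-Python simulation in which a list index stands for a rank. The functions all_gather and all_reduce_sum below are toy stand-ins with made-up inputs, not the torch.distributed calls themselves.

```python
def all_gather(per_rank_values):
    """Every rank receives the list of all ranks' values, ordered by rank."""
    gathered = list(per_rank_values)
    return [list(gathered) for _ in per_rank_values]


def all_reduce_sum(per_rank_values):
    """Every rank receives the sum over all ranks."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]


local_labels = [[0, 1], [2, 3], [4, 5]]   # one label batch per rank (made up)
local_losses = [0.5, 0.25, 0.25]          # one scalar loss per rank (made up)

print(all_gather(local_labels)[1])    # [[0, 1], [2, 3], [4, 5]]
print(all_reduce_sum(local_losses))   # [1.0, 1.0, 1.0]
```

The difference to keep in mind: all_gather keeps every rank's contribution separate, while all_reduce combines them into a single value.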

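3.6 Save the model

# Only rank 0 writes the file. backbone.module is the original model inside
# the DistributedDataParallel wrapper.
if rank == 0:
    path_module = os.path.join(cfg.output, "model_final.pt")
    torch.save(backbone.module.state_dict(), path_module)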
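Saving backbone.module.state_dict() rather than the wrapper's own state dict keeps the parameter names free of the "module." prefix that the DistributedDataParallel wrapper adds. If a checkpoint was saved from the wrapper anyway, the prefix can be removed when loading. The helper below is a hypothetical sketch that works on any dictionary; the values are placeholders instead of real tensors.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading 'module.' removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they would look if the wrapper itself had been saved.
wrapped = {"module.conv1.weight": 1, "module.fc.bias": 2}
print(strip_module_prefix(wrapped))   # {'conv1.weight': 1, 'fc.bias': 2}
```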

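3.7 Launch the parallel program

(1) With torch.distributed.launch


python -m torch.distributed.launch --nproc_per_node=8 train.py configs/ms1mv3_r50

(2) With torch.multiprocessing


def main(rank):
    pass

# Pass args and nprocs by keyword: the positional order of spawn is (fn, args, nprocs).
torch.multiprocessing.spawn(main, args=args, nprocs=nprocs)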

3.8 Complete code

import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss
from lr_scheduler import PolyScheduler
from partial_fc import PartialFC, PartialFCAdamW
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_logging

assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."

try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    rank = 0
    local_rank = 0
    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://",  # incomplete in the source; must be followed by host:port
        rank=rank,
        world_size=world_size,
    )


def main(args):
    # get config
    cfg = get_config(args.config)
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    summary_writer = (
        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
        if rank == 0
        else None
    )

    wandb_logger = None
    if cfg.using_wandb:
        import wandb
        # Sign in to wandb
        try:
            wandb.login(key=cfg.wandb_key)
        except Exception as e:
            print("WandB Key must be provided in config file (base.py).")
            print(f"Config Error: {e}")
        # Initialize wandb
        run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
        run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
        try:
            wandb_logger = wandb.init(
                entity=cfg.wandb_entity,
                project=cfg.wandb_project,
                sync_tensorboard=True,
                resume=cfg.wandb_resume,
                name=run_name,
                notes=cfg.notes) if rank == 0 or cfg.wandb_log_all else None
            if wandb_logger:
                wandb_logger.config.update(cfg)
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
            print(f"Config Error: {e}")

    train_loader = get_dataloader(
        cfg.rec,
        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.seed,
        cfg.num_workers
    )

    backbone = get_model(
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)

    backbone.train()
    # FIXME using gradient checkpoint if there are some unused parameters will cause error
    backbone._set_static_graph()

    margin_loss = CombinedMarginLoss(
        64,
        cfg.margin_list[0],
        cfg.margin_list[1],
        cfg.margin_list[2],
        cfg.interclass_filtering_threshold
    )

    if cfg.optimizer == "sgd":
        module_partial_fc = PartialFC(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        # TODO the params of partial fc must be last in the params list
        opt = torch.optim.SGD(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)

    elif cfg.optimizer == "adamw":
        module_partial_fc = PartialFCAdamW(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        opt = torch.optim.AdamW(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, weight_decay=cfg.weight_decay)
    else:
        # a bare `raise` has no active exception here, so raise an explicit error
        raise ValueError(f"unsupported optimizer: {cfg.optimizer}")

    cfg.total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch

    lr_scheduler = PolyScheduler(
        optimizer=opt,
        base_lr=cfg.lr,
        max_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step,
        last_epoch=-1
    )

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
        start_epoch = dict_checkpoint["epoch"]
        global_step = dict_checkpoint["global_step"]
        backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
        module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
        opt.load_state_dict(dict_checkpoint["state_optimizer"])
        lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
        del dict_checkpoint

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    callback_verification = CallBackVerification(
        val_targets=cfg.val_targets, rec_prefix=cfg.rec,
        summary_writer=summary_writer, wandb_logger=wandb_logger
    )
    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        start_step=global_step,
        writer=summary_writer
    )

    loss_am = AverageMeter()
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

    for epoch in range(start_epoch, cfg.num_epoch):

        if isinstance(train_loader, DataLoader):
            train_loader.sampler.set_epoch(epoch)
        for _, (img, local_labels) in enumerate(train_loader):
            global_step += 1
            local_embeddings = backbone(img)
            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)

            if cfg.fp16:
                amp.scale(loss).backward()
                amp.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                amp.step(opt)
                amp.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                opt.step()

            opt.zero_grad()
            lr_scheduler.step()

            with torch.no_grad():
                if wandb_logger:
                    wandb_logger.log({
                        'Loss/Step Loss': loss.item(),
                        'Loss/Train Loss': loss_am.avg,
                        'Process/Step': global_step,
                        'Process/Epoch': epoch
                    })

                loss_am.update(loss.item(), 1)
                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)

                if global_step % cfg.verbose == 0 and global_step > 0:
                    callback_verification(global_step, backbone)

        if cfg.save_all_states:
            checkpoint = {
                "epoch": epoch + 1,
                "global_step": global_step,
                "state_dict_backbone": backbone.module.state_dict(),
                "state_dict_softmax_fc": module_partial_fc.state_dict(),
                "state_optimizer": opt.state_dict(),
                "state_lr_scheduler": lr_scheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))

        if rank == 0:
            path_module = os.path.join(cfg.output, "model.pt")
            torch.save(backbone.module.state_dict(), path_module)

            if wandb_logger and cfg.save_artifacts:
                artifact_name = f"{run_name}_E{epoch}"
                model = wandb.Artifact(artifact_name, type='model')
                model.add_file(path_module)
                wandb_logger.log_artifact(model)

        if cfg.dali:
            train_loader.reset()

    if rank == 0:
        path_module = os.path.join(cfg.output, "model.pt")
        torch.save(backbone.module.state_dict(), path_module)

        from torch2onnx import convert_onnx
        convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))

        if wandb_logger and cfg.save_artifacts:
            artifact_name = f"{run_name}_Final"
            model = wandb.Artifact(artifact_name, type='model')
            model.add_file(path_module)
            wandb_logger.log_artifact(model)

    distributed.destroy_process_group()


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(
        description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
    main(parser.parse_args())
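The step bookkeeping in the listing (total_batch_size, warmup_step, total_step) depends on the number of processes, which is easy to overlook. The sketch below repeats that arithmetic with invented numbers; schedule_steps is a hypothetical helper and the values are not taken from any real configuration.

```python
def schedule_steps(num_image, batch_size, world_size, num_epoch, warmup_epoch):
    """Mirror the total_batch_size / warmup_step / total_step arithmetic."""
    total_batch_size = batch_size * world_size          # samples per optimizer step
    steps_per_epoch = num_image // total_batch_size     # integer division, as in the listing
    return {
        "total_batch_size": total_batch_size,
        "warmup_step": steps_per_epoch * warmup_epoch,
        "total_step": steps_per_epoch * num_epoch,
    }


# Made-up values: 1,000,000 images, 128 per GPU, 8 GPUs, 20 epochs, 1 warmup epoch.
print(schedule_steps(1_000_000, 128, 8, 20, 1))
# {'total_batch_size': 1024, 'warmup_step': 976, 'total_step': 19520}
```

Doubling the number of GPUs at a fixed per-GPU batch size halves the number of steps per epoch, so the learning-rate schedule has to be derived from these values rather than hard-coded.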

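4 Problems you may run into with distributed training

4.1 RuntimeError: Address already in use


4.2 Problems that may occur while debugging

  • GPU memory not released: check with nvidia-smi whether the memory has been freed. If it has not, free it with kill -9 PID. If kill does not free the memory either, close the terminal and open a new one.
  • Port already in use: if the second debugging run reports that port xx is in use after a first run, the quickest fix is to close the current terminal and open a new one. Alternatively, as for the first problem, kill the corresponding PID.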
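Before relaunching, it can help to check whether the port from the previous run is really still taken. The following standard-library sketch does that by trying to bind to the port; port_in_use is a hypothetical helper, and the demonstration holds a port open itself so that the check has something to find.

```python
import socket


def port_in_use(port, host="127.0.0.1"):
    """Return True if binding to (host, port) fails because the port is taken."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return True
    return False


# Demonstration: hold a port open, then probe it.
holder = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
holder.bind(("127.0.0.1", 0))     # let the OS pick a free port
holder.listen(1)
busy_port = holder.getsockname()[1]
was_busy = port_in_use(busy_port)
holder.close()

print(was_busy)   # True
```

If the check reports that the port is taken, either stop the process that holds it or launch with a different master port.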






