torch.nn.DataParallel ==> 简称 DP

torch.nn.parallel.DistributedDataParallel ==> 简称DDP



为了让模型更快 train 完,我把长度相近的文本打包成一个 batch(温馨提醒:torchtext 也有相关的类 BucketIterator[1]),大概形式如下:

  class BucketSampler(torch.utils.data.Sampler):
      """Single-GPU sampler that batches samples of similar length together.

      Sample indices are ordered by length bucket via ``basesampler``; batches
      are then cut so that the longer the longest sample in the batch, the
      smaller the batch may grow (to avoid running out of GPU memory):

      * longest length < 100        -> up to ``batch_size`` samples
      * 100 <= longest length < 220 -> up to the middle-bucket cap
      * longest length >= 220       -> up to the long-bucket cap

      Args:
          data_source: an indexable dataset; ``len(data_source[i])`` is used
              as the length of sample ``i``.
          batch_size: maximum number of samples per batch.

      Yields:
          Lists of dataset indices, one list per batch.
      """

      def __init__(self, data_source, batch_size=32):
          self.data_source = data_source
          self.batch_size = batch_size

      def __iter__(self):
          indices = list(range(len(self.data_source)))
          lens = [len(self.data_source[i]) for i in indices]
          # basesampler(lens, indices, batch_size) returns the bucketed index
          # order, an empty batch list, and the middle/long batch caps.
          idxs, batch, middle_batch_size, long_batch_size = basesampler(
              lens, indices, self.batch_size)
          mlen = 0  # running max length of the batch being built
          for idx in idxs:
              batch.append(idx)
              mlen = max(mlen, lens[idx])
              # Half-open ranges so lengths of exactly 100 or 220 still get a cap.
              if mlen < 100:
                  limit = self.batch_size
              elif mlen < 220:
                  limit = middle_batch_size
              else:
                  limit = long_batch_size
              # Never exceed batch_size, whatever the bucket.
              if len(batch) >= min(limit, self.batch_size):
                  yield batch
                  batch = []
                  mlen = 0
          if batch:
              yield batch

      def __len__(self):
          # Number of full-size batches. This is a lower bound: batches from
          # the middle/long buckets are smaller, so there can be more of them.
          return (len(self.data_source) + self.batch_size - 1) // self.batch_size




# The author's DDP length-bucketing sampler; it subclasses DistributedSampler.
# The bullets below list the problems hit while writing it.
class DDPBaseBucketSampler(torch.utils.data.distributed.DistributedSampler):


  • dataloader不会发包;

  • dataloader 给每个进程发的是完整的数据。按理说,每个进程应该只拿到 1/n 的数据,n 为你设置的 GPU 数量;


  def __iter__(self) -> Iterator[T_co]:
      # Excerpt of DistributedSampler.__iter__ from PyTorch, re-indented.
      if self.shuffle:
          # deterministically shuffle based on epoch and seed
          g = torch.Generator()
          g.manual_seed(self.seed + self.epoch)
          indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore
      else:
          indices = list(range(len(self.dataset)))  # type: ignore

      if not self.drop_last:
          # add extra samples to make it evenly divisible
          padding_size = self.total_size - len(indices)
          if padding_size <= len(indices):
              indices += indices[:padding_size]
          else:
              indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
      else:
          # remove tail of data to make it evenly divisible.
          indices = indices[:self.total_size]
      assert len(indices) == self.total_size

      # subsample
      # This strided slice is what gives each process a different, disjoint
      # part of the data.
      indices = indices[self.rank:self.total_size:self.num_replicas]
      assert len(indices) == self.num_samples

      return iter(indices)


  def __iter__(self) -> Iterator[T_co]:
      # Placeholder: a concrete sampler must override __iter__ itself.
      raise NotImplementedError



  def basesampler(lens, indices, batch_size):
      """Group sample indices into length buckets and derive per-bucket batch caps.

      Shared by the single-GPU and the DDP bucket samplers.

      Args:
          lens: sample lengths aligned with ``indices``; ``lens[k]`` is the
              length of sample ``indices[k]``.
          indices: dataset indices to bucket.
          batch_size: the maximum batch size (used for the short bucket).

      Returns:
          ``(idxs, batch, middle_batch_size, long_batch_size)``:

          * ``idxs`` contains ``indices`` regrouped into three buckets, in
            random bucket order. The buckets are short (< 100), middle
            (100 <= length < 220) and long (>= 220). Order within a bucket is
            preserved.
          * ``batch`` is a fresh empty list for the caller to fill.
          * ``middle_batch_size`` and ``long_batch_size`` are the batch-size
            caps for the middle and long buckets.
      """
      # The magic numbers (100 / 220) come from the author's code.
      # Half-open ranges put every length in exactly one bucket. Previously a
      # length of exactly 100 fell through to the long bucket.
      short, middle, long_ = [], [], []
      for length, idx in zip(lens, indices):
          if length < 100:
              short.append(idx)
          elif length < 220:
              middle.append(idx)
          else:
              long_.append(idx)
      buckets = [short, middle, long_]
      random.shuffle(buckets)
      idxs = [idx for bucket in buckets for idx in bucket]
      batch = []
      # Smaller caps for longer samples so a batch does not run out of GPU memory.
      middle_batch_size = min(int(batch_size * 0.75), 32)
      long_batch_size = min(int(batch_size * 0.5), 24)
      return idxs, batch, middle_batch_size, long_batch_size
  class DDPBaseBucketSampler(torch.utils.data.distributed.DistributedSampler):
      """DistributedSampler variant that batches samples of similar length.

      Each rank first takes its own strided share of the (optionally
      shuffled) dataset, as DistributedSampler does. It then cuts that share
      into length-bucketed batches using ``basesampler``. Keep this class in
      sync with the single-GPU sampler class.

      NOTE(review): batch sizes depend on sample lengths, so different ranks
      can yield a different number of batches per epoch. A DDP training loop
      that synchronizes every step may then block on the rank with more
      batches. Confirm that the caller handles this.

      Args:
          dataset: an indexable dataset; ``len(dataset[i])`` is the length of
              sample ``i``.
          num_replicas, rank, shuffle: forwarded to DistributedSampler.
          batch_size: maximum number of samples per batch.
      """

      def __init__(self, dataset, num_replicas, rank, shuffle=True, batch_size=32):
          super().__init__(dataset, num_replicas, rank, shuffle)
          self.batch_size = batch_size

      def __iter__(self):
          # Deterministically shuffle based on epoch. The seed is included on
          # PyTorch versions whose DistributedSampler has one. Every rank thus
          # builds the same permutation before taking its share.
          g = torch.Generator()
          g.manual_seed(getattr(self, "seed", 0) + self.epoch)
          if self.shuffle:
              indices = torch.randperm(len(self.dataset), generator=g).tolist()
          else:
              indices = list(range(len(self.dataset)))

          # Pad by repeating indices until the list divides evenly across
          # ranks. The loop also works when the padding needed is longer than
          # the dataset itself.
          while indices and len(indices) < self.total_size:
              indices += indices[:self.total_size - len(indices)]
          assert len(indices) == self.total_size

          # Strided slice: each rank gets a different share of the data.
          indices = indices[self.rank:self.total_size:self.num_replicas]
          assert len(indices) == self.num_samples

          # Only this rank's samples need their length measured.
          lens = {idx: len(self.dataset[idx]) for idx in indices}
          idxs, batch, middle_batch_size, long_batch_size = basesampler(
              [lens[idx] for idx in indices], indices, self.batch_size)

          mlen = 0  # running max length of the batch being built
          for idx in idxs:
              batch.append(idx)
              mlen = max(mlen, lens[idx])
              # Half-open ranges so lengths of exactly 100 or 220 still get a cap.
              if mlen < 100:
                  limit = self.batch_size
              elif mlen < 220:
                  limit = middle_batch_size
              else:
                  limit = long_batch_size
              if len(batch) >= min(limit, self.batch_size):
                  yield batch
                  batch = []
                  mlen = 0
          if batch:
              yield batch

      def __len__(self):
          # Full-size batches in THIS rank's share (num_samples), not in the
          # whole dataset. This is a lower bound: middle/long buckets use
          # smaller batches.
          return (self.num_samples + self.batch_size - 1) // self.batch_size




在 PyTorch DDP 下,DataLoader 的 num_workers 大于 0 时程序无法正常结束。具体表现为:mp.spawn 传入的函数可以顺利运行完,但是 master 进程一直占着卡,不退出。一开始我怀疑是 sampler 分发 batch 的机制导致的。什么意思呢?每个进程拿到的数据不一样,而我又规定了长度接近的文本打包在一起,所以各个进程执行 sampler 类时产生的 batch 数可能不同,比如 master 进程有 100 个 iter,slave 只有 80 个。然后我马上试了一下,很快啊:





  # After training, tear down the default process group (DDP runs only) and
  # log which rank reached this point.
  if args.is_ddp:
      dist.destroy_process_group()
      print('rank destroy_process_group: ' , rank)


  1. File "train.py", line 322, in <module>
  2. main(args.gpu, args)
  3. File "/home/lzk/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
  4. while not spawn_context.join():
  5. File "/home/lzk/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 77, in join
  6. timeout=timeout,
  7. File "/home/lzk/anaconda3/lib/python3.7/multiprocessing/connection.py", line 920, in wait
  8. ready = selector.select(timeout)
  9. File "/home/lzk/anaconda3/lib/python3.7/selectors.py", line 415, in select
  10. fd_event_list = self._selector.poll(timeout)
  11. TypeError: keyboard_interrupt_handler() takes 1 positional argument but 2 were given
  12. ^CError in atexit._run_exitfuncs:
  13. Traceback (most recent call last):
  14. File "/home/lzk/anaconda3/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
  15. pid, sts = os.waitpid(self.pid, flag)
  16. TypeError: keyboard_interrupt_handler() takes 1 positional argument but 2 were given

代码参考:基于Python初探Linux下的僵尸进程和孤儿进程(三)[3]、 Multiprocessing in python blocked[4]

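看起来是 PyTorch 的 master 进程发生了死锁,一直无法退出。我一开始以为它变成了僵尸进程,但严格来说并不是:僵尸进程指的是已经退出、但还没有被父进程回收的进程。而从上面的 traceback 可以看到,这里的 master 进程仍然阻塞在 spawn 的 join 上,等待子进程结束。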

再探究,发现当我把 DataLoader 的 num_workers 设为 0 的时候,程序可以正常结束。经过我的注释大法后我发现,哪怕我把 for _i, batch in enumerate(dataloader) 循环内的代码全部注释掉、改为 pass,master 仍然无法正常结束。所以问题锁定在 DataLoader 身上。参考:nero:PyTorch DataLoader初探[5]

另外一种想法是,mp.spawn 出现了问题。mp.spawn 默认以 spawn 方式启动子进程,使用此方式启动的进程,只会执行和 target 参数或者 run() 方法相关的代码。Windows 平台只能使用此方法,事实上该平台默认使用的也是该启动方式。相比其他两种方式(fork 和 forkserver),此方式启动进程的效率最低。参考:Python设置进程启动的3种方式[6]


python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=23456 我的文件.py


  • nnodes:因为是单机多卡,所以设为1,显然node_rank 只能是0了

  • local_rank:launch 在启动每个进程时,会通过命令行参数 --local_rank 传入该进程在本机上的序号(如果加了 --use_env,则改为通过环境变量 LOCAL_RANK 传入)





顺着报错路径去torch/distributed/launch.py, line 239找代码:

  def main():
      # Excerpt of torch/distributed/launch.py (legacy launcher), re-indented:
      # it starts nproc_per_node copies of the training script and waits for them.
      args = parse_args()

      # world size in terms of number of processes
      dist_world_size = args.nproc_per_node * args.nnodes

      # set PyTorch distributed related environmental variables
      current_env = os.environ.copy()
      current_env["MASTER_ADDR"] = args.master_addr
      current_env["MASTER_PORT"] = str(args.master_port)
      current_env["WORLD_SIZE"] = str(dist_world_size)

      processes = []

      if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
          current_env["OMP_NUM_THREADS"] = str(1)
          print("*****************************************\n"
                "Setting OMP_NUM_THREADS environment variable for each process "
                "to be {} in default, to avoid your system being overloaded, "
                "please further tune the variable for optimal performance in "
                "your application as needed. \n"
                "*****************************************".format(current_env["OMP_NUM_THREADS"]))

      for local_rank in range(0, args.nproc_per_node):
          # each process's rank
          dist_rank = args.nproc_per_node * args.node_rank + local_rank
          current_env["RANK"] = str(dist_rank)
          current_env["LOCAL_RANK"] = str(local_rank)

          # spawn the processes
          if args.use_env:
              cmd = [sys.executable, "-u",
                     args.training_script] + args.training_script_args
          else:
              cmd = [sys.executable,
                     "-u",
                     args.training_script,
                     "--local_rank={}".format(local_rank)] + args.training_script_args

          process = subprocess.Popen(cmd, env=current_env)
          processes.append(process)

      for process in processes:
          # Block until each child exits, in launch order.
          process.wait()
          if process.returncode != 0:
              # NOTE: `cmd` is the command of the LAST process spawned, not
              # necessarily that of the process that failed.
              raise subprocess.CalledProcessError(returncode=process.returncode,
                                                  cmd=cmd)



  def main(args):
      """Per-process entry point under torch.distributed.launch (excerpt).

      Pins this process to one GPU (stored in ``args.device``) and, when
      ``args.is_ddp`` is set, joins the NCCL process group.
      """
      print('local_rank : ' , args.local_rank )

      # Restrict this process to its own GPU before anything can touch CUDA.
      # CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA
      # runtime initializes. After masking, that GPU is device 0.
      visible = os.environ.get("CUDA_VISIBLE_DEVICES")
      if visible:
          os.environ["CUDA_VISIBLE_DEVICES"] = visible.split(',')[args.local_rank]
          args.device = torch.device(0)
      else:
          # No mask set (the original raised KeyError here): address the GPU
          # by its local rank instead.
          args.device = torch.device(args.local_rank)

      if args.is_ddp:
          dist.init_process_group(
              backend='nccl',
              init_method='env://',
              world_size=args.world_size,
              # local_rank equals the global rank only on a single node
              # (nnodes=1), the setup used in this article.
              rank=args.local_rank,
          )

      # torch.multiprocessing.set_sharing_strategy('file_system')  -- "the root of all evil", deliberately disabled
      ...

为什么我当时会加上这句话呢?因为当时在调试 num_workers 的时候(当时年轻,以为越大越好,所以设置成了 num_workers = os.cpu_count()),发现系统报错,说超出了打开文件的最大数量限制。在 torch.multiprocessing 的设定里,共享策略(参考 PyTorch 中文文档[7])在 Linux 上默认是 file_descriptor,此策略使用文件描述符作为共享内存句柄:每当一个存储被移动到共享内存中时,通过 shm_open 获得的文件描述符就会被缓存在该对象中。文档还提到:





