
Demonstrating pipeline parallelism with PyTorch


1. When the model does not fit in a single GPU's memory, the network can be cut into several segments (stages), with each GPU owning one stage: GPU0 finishes its part of the forward pass and sends the activations to GPU1, which computes the next stage, and so on.
2. Run this way, GPU utilization is poor, because at any moment only one GPU is working. Splitting the input batch into several smaller micro-batches and feeding them to GPU0 one after another fixes this: as soon as GPU0 has finished micro batch 0 it can start on micro batch 1 while the later stages work on the earlier micro-batches, which keeps every GPU busy. A minimal sketch of the micro-batch split follows below.
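Before the full script, here is a minimal illustrative sketch of the micro-batch split (the shapes mirror the demo's launch arguments, but the split size of 4 is chosen only to show several micro-batches): torch.split cuts the batch along dim 0, and a stage forwards one micro-batch at a time while the next one is already queued behind it. In a GPipe-style schedule with p stages and m micro-batches the idle "bubble" is roughly (p-1)/(m+p-1) of the time, e.g. 3/4 with 4 stages and a single undivided batch, but only 3/11 with 8 micro-batches.

import torch

batch = torch.rand(16, 512, 512)              # (batch_size, seq_len, hidden_size)
micro_batches = torch.split(batch, 4, dim=0)  # four micro-batches of shape (4, 512, 512)
for i, mb in enumerate(micro_batches):
    print(i, mb.shape)                        # each stage forwards mb, then sends it downstream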

tee pp_demo.py <<-'EOF'
import os
import torch
from torch import nn
import torch.distributed as dist
import time
import argparse
  
parser = argparse.ArgumentParser(description="pipeline-parallel FFN demo")
parser.add_argument('--hidden_size', default=512, type=int, help='hidden size of the FFN')
parser.add_argument('--ffn_size', default=1024, type=int, help='intermediate size of the FFN')
parser.add_argument('--seq_len', default=512, type=int, help='sequence length of the dummy input')
parser.add_argument('--batch_size', default=8, type=int, help='total batch size before micro-batch splitting')
parser.add_argument('--world_size', default=4, type=int, help='number of pipeline stages (GPUs)')
parser.add_argument('--device', default="cuda", type=str, help='device to run on')
parser.add_argument('--chunk_size', default=1, type=int, help='micro-batch size used by torch.split')
 
# a two-layer feed-forward block; every pipeline stage runs one copy of it
class FeedForward(nn.Module):
  
    def __init__(self,hidden_size,ffn_size):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_size,bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size,bias=False)
  
    def forward(self, input):
        return self.fc2(self.fc1(input))
 
args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.device
chunk_size = args.chunk_size
 
# each torchrun process runs one pipeline stage of the model
def pp_mode():
  torch.random.manual_seed(1)
  dist.init_process_group(backend='nccl')
      
  world_size = torch.distributed.get_world_size()
  rank = torch.distributed.get_rank()
  local_rank=int(os.environ['LOCAL_RANK'])
    
  torch.cuda.set_device(local_rank)
  device = torch.device("cuda",local_rank)
 
  model = FeedForward(hidden_size,ffn_size)
  model.eval()
  # same seed on every rank; on ranks > 0 this tensor is overwritten by recv() and
  # simply serves as the receive buffer for the previous stage's activations
  input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)
  model=model.half().to(device)
 
  index=0
  count=0
  t0=0

  # split the batch along dim 0 into micro-batches of chunk_size samples
  chunks=torch.split(input,chunk_size,dim=0)

  for epoch in range(32):
    index+=1
    # treat the first iteration as warm-up, then report iterations/s every 10 iterations on rank 0
    if index>1:
      count+=1
      if t0==0:
        t0=time.time()
      if count%10==0 and rank==0:
        print("qps:{:.2f}".format(count/(time.time()-t0)))
        count=0
        t0=0
 
    all_output=[]

    for chunk in chunks:
      if rank==0:
        out=model(chunk)
      else:
        # receive this micro-batch's activations from the previous stage
        torch.distributed.recv(chunk,rank-1)
        out=model(chunk)
      if rank==world_size-1:
        all_output.append(out.clone())
      else:
        # pass the activations on to the next stage
        torch.distributed.send(out,rank+1)

    if rank==world_size-1:
      # the last stage reassembles the full batch from its micro-batch outputs
      out=torch.cat(all_output,dim=0)
 
if __name__ == "__main__":
  num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
  is_distributed = num_gpus > 1
  if is_distributed:
    pp_mode()
EOF
 
torchrun --nnodes=1 --nproc_per_node=4 pp_demo.py \
		--hidden_size 512 --ffn_size 4096 --seq_len 512 \
		--batch_size 16 --world_size 4 --chunk_size 8
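The per-iteration loop in pp_demo.py uses blocking torch.distributed.send/recv, so each stage waits for the transfer of one micro-batch to finish before touching the next one. A possible variant (a sketch only, not part of the original script; it reuses rank, world_size, model and chunks exactly as defined in pp_demo.py) queues each transfer with the non-blocking isend and waits on all requests once the loop is done. The sent tensors are kept next to their requests so the buffers cannot be freed while a transfer is still in flight.

import torch.distributed as dist

all_output = []
send_reqs = []                                # (request, tensor) pairs keep the buffers alive
for chunk in chunks:
    if rank > 0:
        dist.recv(chunk, rank - 1)            # must block: the forward needs this data
    out = model(chunk)
    if rank == world_size - 1:
        all_output.append(out.clone())
    else:
        req = dist.isend(out, rank + 1)       # queue the transfer, move on to the next micro-batch
        send_reqs.append((req, out))
for req, _ in send_reqs:
    req.wait()                                # drain all outstanding sends before the next iteration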