tee pp_demo.py <<-'EOF'
import os
import time
import argparse

import torch
from torch import nn
import torch.distributed as dist

parser = argparse.ArgumentParser(description="Naive pipeline-parallel inference demo")
parser.add_argument('--hidden_size', default=512, type=int, help='hidden dimension of the FFN')
parser.add_argument('--ffn_size', default=1024, type=int, help='intermediate dimension of the FFN')
parser.add_argument('--seq_len', default=512, type=int, help='sequence length')
parser.add_argument('--batch_size', default=8, type=int, help='global batch size')
parser.add_argument('--world_size', default=4, type=int, help='number of pipeline stages')
parser.add_argument('--device', default="cuda", type=str, help='device type')
parser.add_argument('--chunk_size', default=1, type=int, help='micro-batch size along the batch dimension')


class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_size):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_size, bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size, bias=False)

    def forward(self, input):
        return self.fc2(self.fc1(input))


args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.device
chunk_size = args.chunk_size


def pp_mode():
    # One pipeline stage per rank; the same seed gives every rank identical weights.
    torch.random.manual_seed(1)
    dist.init_process_group(backend='nccl')
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = FeedForward(hidden_size, ffn_size)
    model.eval()
    # Rank 0 treats `input` as real data; the other ranks reuse it as a receive buffer.
    input = torch.rand((batch_size, seq_len, hidden_size), dtype=torch.float32).half().to(device)
    model = model.half().to(device)

    index = 0
    count = 0
    t0 = 0
    # Split the batch into micro-batches so downstream stages can start working earlier.
    chunks = torch.split(input, chunk_size, dim=0)
    with torch.no_grad():
        for epoch in range(32):
            index += 1
            if index > 1:
                count += 1
            if t0 == 0:
                t0 = time.time()
            if count > 0 and count % 10 == 0 and rank == 0:
                print("qps:{:.2f}".format(count / (time.time() - t0)))
                count = 0
                t0 = 0

            all_output = []
            for chunk in chunks:
                if rank == 0:
                    out = model(chunk)
                else:
                    # Receive this micro-batch's activation from the previous stage.
                    dist.recv(chunk, rank - 1)
                    out = model(chunk)
                if rank == world_size - 1:
                    all_output.append(out.clone())
                else:
                    # Forward the activation to the next stage (blocking send).
                    dist.send(out, rank + 1)
            if rank == world_size - 1:
                out = torch.cat(all_output, dim=0)


if __name__ == "__main__":
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    is_distributed = num_gpus > 1
    if is_distributed:
        pp_mode()
EOF
torchrun --nnodes=1 --nproc_per_node=4 pp_demo.py \
    --hidden_size 512 --ffn_size 4096 --seq_len 512 \
    --batch_size 16 --world_size 4 --chunk_size 8
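The demo splits the batch into micro-batches with torch.split and, on the last rank, concatenates the per-chunk outputs back together. This only gives the same result as processing the whole batch because an FFN applied to micro-batches along the batch dimension is equivalent to applying it to the full batch. The sketch below is not part of the script above; it is a minimal single-process check (CPU, smaller shapes than the demo) of that equivalence:

import torch
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, ffn_size, bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size, bias=False)

    def forward(self, x):
        return self.fc2(self.fc1(x))

torch.manual_seed(1)
model = FeedForward(512, 1024).eval()
x = torch.rand(16, 8, 512)  # (batch, seq, hidden); small shapes for a quick CPU run
with torch.no_grad():
    full = model(x)  # whole batch in one pass
    # micro-batches of size 8 along dim 0, then concatenate the outputs
    chunked = torch.cat([model(c) for c in torch.split(x, 8, dim=0)], dim=0)
print(torch.allclose(full, chunked, atol=1e-6))  # expect True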