Andrej Karpathy rebuilds GPT-2 in PyTorch.
Takeaway Points
# torchrun --standalone --nproc_per_node=<num_gpus_per_node> <your_training_script.py> <script_arguments>
# The command above only applies to single-node training.
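# For multi-node training, torchrun additionally needs the node count, this node's index,
# and a rendezvous endpoint (host/port below are placeholders, not values from the original notes):
# torchrun --nnodes=2 --node_rank=0 --nproc_per_node=8 \
#          --rdzv_backend=c10d --rdzv_endpoint=<master_host>:29500 <your_training_script.py>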
import os
import time

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# SETTINGS FOR EACH DIFFERENT RANK
ddp = int(os.environ.get('RANK', -1)) != -1  # torchrun sets RANK; a plain launch does not
if ddp:
    assert torch.cuda.is_available(), "DDP run requested but CUDA is not available"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])              # global rank, unique for each process
    ddp_local_rank = int(os.environ['LOCAL_RANK'])  # rank within the local machine (node)
    ddp_world_size = int(os.environ['WORLD_SIZE'])  # total number of processes (GPUs)
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process does the logging, checkpointing, etc.
else:
    # single-process fallback
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")
device_type = "cuda" if device.startswith("cuda") else "cpu"  # autocast wants "cuda", not "cuda:0"
# IF YOU USE GRADIENT ACCUMULATION
total_batch_size = 524288  # desired batch size, measured in number of tokens (2**19)
B = 16    # micro-batch size for each process
T = 1024  # sequence length
assert total_batch_size % (B * T * ddp_world_size) == 0, "total_batch_size must be divisible by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
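# Worked example (assuming 8 GPUs): each micro step consumes B * T * ddp_world_size
# = 16 * 1024 * 8 = 131072 tokens, so grad_accum_steps = 524288 // 131072 = 4
# backward passes are accumulated before every optimizer step.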
# SET DATALOADER
train_loader = DataLoader(*args, ddp_world_size, ddp_rank)  # MUST: each process has to read a different shard of the dataset (see the sketch just below)
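# A minimal sketch of such a rank-aware loader (DataLoaderLite and the `tokens` tensor
# are illustrative assumptions, not part of any library): each process starts at its own
# offset and strides forward by the total number of processes.
class DataLoaderLite:
    def __init__(self, tokens, B, T, process_rank, num_processes):
        self.tokens = tokens  # 1-D tensor of token ids
        self.B, self.T = B, T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.current_position = B * T * process_rank  # per-rank starting offset

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets, shifted by one token
        self.current_position += B * T * self.num_processes  # skip over the other ranks' batches
        if self.current_position + B * T * self.num_processes + 1 > len(self.tokens):
            self.current_position = B * T * self.process_rank  # wrap around
        return x, y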
# CREATE MODEL
model = create_model()  # placeholder for your GPT-2 model construction
model.to(device)
model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])  # this must be ddp_local_rank, not ddp_rank
raw_model = model.module if ddp else model  # the unwrapped model underneath the DDP wrapper

# SET OPTIMIZER (AdamW with these hyperparameters is a generic stand-in, not the original script's exact configuration)
optimizer = torch.optim.AdamW(raw_model.parameters(), lr=6e-4, weight_decay=0.1)
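# One common reason to keep raw_model around is checkpointing: saving the unwrapped module
# avoids the 'module.' prefix DDP adds to parameter names, and only the master process should
# write files. (Sketch only; the helper name and filename pattern are assumptions.)
def save_checkpoint(step):
    if master_process:
        torch.save({'model': raw_model.state_dict(), 'step': step}, f'checkpoint_{step:06d}.pt')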
# FIX SEED
seed = 1337  # your lucky number; torch.manual_seed expects an integer
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
# TRAIN
for step in range(max_steps):
    t0 = time.time()
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps  # average, so the accumulated gradient matches one big batch
        loss_accum += loss.detach()
        if ddp:
            # Syncing gradients on every micro step wastes communication; only the last
            # backward() of an accumulation cycle needs the all-reduce. The official way is
            # the ddp.no_sync() context manager (see the sketch at the end), but toggling
            # this flag, as done here, works too.
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        loss.backward()
    if ddp:
        # average the reported loss across all processes (for logging only)
        torch.distributed.all_reduce(loss_accum, op=torch.distributed.ReduceOp.AVG)
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip the global grad norm to 1.0
    optimizer.step()
    if master_process:
        dt = time.time() - t0  # seconds per step
        print(f"step {step} | loss {loss_accum.item():.6f} | norm {norm:.4f} | dt {dt*1000:.1f}ms")
    if step % 100 == 0:
        # start evaluation
        model.eval()
        with torch.no_grad():
            pass  # SOME EVALUATION CODE
if ddp:
    destroy_process_group()
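# For reference, the same accumulation loop written with DDP's official no_sync() context
# manager, as mentioned in the comment above (a sketch under the same variable names as the
# script; nullcontext is used so the code also runs without DDP):
from contextlib import nullcontext

def accumulate_gradients():
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        # suppress the gradient all-reduce on every micro step except the last one
        ctx = model.no_sync() if (ddp and micro_step < grad_accum_steps - 1) else nullcontext()
        with ctx:
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, loss = model(x, y)
            loss = loss / grad_accum_steps
            loss_accum += loss.detach()
            loss.backward()
    return loss_accum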