The earlier approach of intercepting functions on `torch` and `torch.Tensor` cannot see the details of the aten operators during the backward pass. The scheme below, based on the `__torch_dispatch__` mechanism, needs less code and also shows the call stack for every dispatched op.
How it works: [What (and why) is `__torch_dispatch__`?](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)
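As a minimal sketch of the mechanism (a simplified example, not the full script below; it assumes PyTorch >= 2.0, where `TorchDispatchMode` lives in `torch.utils._python_dispatch`): every aten-level call made while the mode is active, including the ops autograd launches during backward, is routed through `__torch_dispatch__`:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PrintOpsMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is an aten OpOverload; its overload packet carries the op name.
        print(func._overloadpacket.__name__)
        return func(*args, **kwargs)

with PrintOpsMode():
    x = torch.randn(4, 4, requires_grad=True)
    y = (x @ x).sum()
    y.backward()  # backward aten ops (mm, expand, ...) are printed as well
```

The full dump script builds on the same idea: it saves every op's input and output tensors to disk and prints the Python call stack at the point of dispatch.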
```python
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import math
import threading
import inspect
from dataclasses import dataclass
from typing import Any

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils._python_dispatch import TorchDispatchMode

device = "cuda"


@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None


lock = threading.Lock()
gindex = 0


def save_tensor(name, args, index=0):
    """Recursively dump every tensor in `args` to a .pt file and print its shape."""
    if isinstance(args, torch.Tensor):
        print(name, index, args.shape)
        global gindex
        lock.acquire()
        torch.save(args, "{}_{}_{}_{}.pt".format(device, gindex, name, index))
        gindex += 1
        lock.release()
    if isinstance(args, tuple):
        for idx, x in enumerate(args):
            save_tensor(name, x, index + idx)


class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Every aten op (forward and backward) passes through here.
        func_packet = func._overloadpacket
        if kwargs is None:
            kwargs = {}
        enable_dump = False
        if func_packet.__name__ not in ["detach"]:
            enable_dump = True
            print(f"Profiling {func_packet.__name__}")
            # Print the Python call stack; the '*' prefix marks the frame depth.
            for idx, stack in enumerate(inspect.stack()):
                print(f'{"*"*idx}{stack.filename}{stack.lineno}')
        if enable_dump:
            save_tensor(f"{func_packet.__name__}-input", args)
        ret = func(*args, **kwargs)
        if enable_dump:
            save_tensor(f"{func_packet.__name__}-output", ret)
        return ret


class TorchDumper:
    _CURRENT_Dumper = None

    def __init__(self, schedule: Any):
        self.p = _ProfilerState(schedule)

    def __enter__(self):
        assert TorchDumper._CURRENT_Dumper is None
        TorchDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)


class Attention(nn.Module):
    def __init__(self, max_seq_len, head_dim, flash):
        super().__init__()
        self.flash = flash
        self.dropout = 0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim = head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor):
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq, xk, xv, attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            seqlen = xq.shape[2]  # slice the causal mask to the current sequence length
            _xk = xk.clone()
            t = _xk.transpose(2, 3)
            scores = torch.matmul(xq, t)
            scores = scores / math.sqrt(self.head_dim)
            a = self.mask[:, :, :seqlen, :seqlen]
            scores = scores + a
            scores = F.softmax(scores.float(), dim=-1)
            scores = scores.type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)
        return output


def main(flash, bs, n_local_heads, seqlen, head_dim):
    torch.random.manual_seed(1)
    q = torch.ones((bs, n_local_heads, seqlen, head_dim), dtype=torch.float32).half().to(device)
    k = torch.ones((bs, n_local_heads, seqlen, head_dim), dtype=torch.float32).half().to(device)
    v = torch.ones((bs, n_local_heads, seqlen, head_dim), dtype=torch.float32).half().to(device)
    q.data.normal_(0, 0.1)
    k.data.normal_(0, 0.1)
    v.data.normal_(0, 0.1)
    q = Variable(q, requires_grad=True).to(device)
    k = Variable(k, requires_grad=True).to(device)
    v = Variable(v, requires_grad=True).to(device)

    gt = torch.randint(0, head_dim, (bs * n_local_heads * seqlen, 1)).reshape(-1).to(device)
    loss_func = nn.CrossEntropyLoss().to(device)
    model = Attention(seqlen, head_dim, flash).half().to(device)
    optim = torch.optim.SGD([q, k, v], lr=1.1)

    # Everything inside this context, including loss.backward(), is dumped.
    with TorchDumper(TorchDumpDispatchMode):
        for i in range(1):
            output = model(q, k, v)
            loss = loss_func(output.reshape(-1, head_dim), gt)
            loss.backward()
            optim.step()
            print("{:.5f},{:.5f},{:.5f},{:.5f}".format(
                q.sum().item(), k.sum().item(), v.sum().item(), loss.item()))


bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False, bs, n_local_heads, seqlen, head_dim)
```
Sample output (non-flash path; truncated):

```
Profiling clone
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py109
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
Profiling transpose
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py110
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 512, 64])
expand-output 0 torch.Size([8, 8, 512, 64])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 512, 64])
view-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 64, 512])
expand-output 0 torch.Size([8, 8, 64, 512])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 64, 512])
view-output 0 torch.Size([8, 8, 64, 512])
Profiling bmm
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
bmm-input 0 torch.Size([64, 512, 64])
bmm-input 1 torch.Size([64, 64, 512])
bmm-output 0 torch.Size([64, 512, 64])
bmm-output 1 torch.Size([64, 64, 512])
Profiling _unsafe_view
```
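Once the `.pt` files are on disk, dumps from two runs (for example the same script run once with `device="cuda"` and once with `device="cpu"`) can be compared pairwise. The helper below is a hypothetical companion, not part of the original script; it only assumes the `{device}_{index}_{op}-{io}_{arg}.pt` naming used by `save_tensor` above and that both runs execute the same op sequence:

```python
import glob
import torch

def compare_dumps(pattern_a="cuda_*.pt", pattern_b="cpu_*.pt", atol=1e-3):
    """Load matching dump files from two runs and report whether they agree."""
    # Lexicographic sort pairs the files consistently as long as both runs
    # dispatched the same ops in the same order.
    files_a = sorted(glob.glob(pattern_a))
    files_b = sorted(glob.glob(pattern_b))
    for fa, fb in zip(files_a, files_b):
        ta = torch.load(fa, map_location="cpu").float()
        tb = torch.load(fb, map_location="cpu").float()
        if ta.shape == tb.shape and torch.allclose(ta, tb, atol=atol):
            print(f"OK        {fa} vs {fb}")
        else:
            print(f"MISMATCH  {fa} vs {fb}")

compare_dumps()
```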