Flash attention implementation (xformers):
import torch
from xformers import ops as xops
import time

bs = 32
seq_len = 512
n_head = 16
head_dim = 64

query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")

# xformers expects (batch, seq_len, n_head, head_dim), so swap the head and sequence dims
flash_query_states = query_states.transpose(1, 2)
flash_key_states = key_states.transpose(1, 2)
flash_value_states = value_states.transpose(1, 2)

start_time = time.time()
# memory-efficient (flash) attention from the xformers acceleration library
flash_attn_output = xops.memory_efficient_attention(
    flash_query_states, flash_key_states, flash_value_states,
    attn_bias=xops.LowerTriangularMask()
)
print(f'flash attention time: {(time.time() - start_time) * 1000} ms')
print(torch.cuda.max_memory_allocated("cuda:0") / 1024**2)  # ~192 MB
print("=============================")
print(torch.cuda.memory_allocated("cuda:0") / 1024**2)      # ~128 MB
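As a quick sanity check, the xformers output can be compared against PyTorch's built-in scaled_dot_product_attention. This is a minimal sketch, assuming PyTorch 2.0+ and the tensors defined in the snippet above; note the layout difference (xformers takes batch, seq, heads, dim, while the built-in op takes batch, heads, seq, dim).

import torch
import torch.nn.functional as F

# is_causal=True plays the same role as xops.LowerTriangularMask()
sdpa_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=True
)
# transpose the xformers output back to (batch, heads, seq, dim) before comparing
print(torch.allclose(flash_attn_output.transpose(1, 2), sdpa_output, rtol=2e-3, atol=2e-3))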
Standard attention implementation:
import torch
import torch.nn as nn
import math
import time

bs = 32
seq_len = 512
n_head = 16
head_dim = 64

query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")

# causal (lower-triangular) mask, converted to additive form for fp16 compatibility
attention_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)).view(1, 1, seq_len, seq_len)
attention_mask = attention_mask.to(dtype=torch.float16).cuda()
attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float16).min

def standard_attention(query_states, key_states, value_states, attention_mask):
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = attn_weights + attention_mask
    # upcast attention to fp32 for numerical stability, then cast back
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2)
    return attn_output

start_time = time.time()
attn_output = standard_attention(query_states, key_states, value_states, attention_mask)
print(f'standard attention time: {(time.time() - start_time) * 1000} ms')
# torch.allclose checks whether two tensors are approximately equal
# (floating-point results may differ by tiny amounts):
# print(torch.allclose(attn_output, flash_attn_output, rtol=2e-3, atol=2e-3))
print(torch.cuda.max_memory_allocated("cuda:0") / 1024**2)  # ~1128 MB
print("=============================")
print(torch.cuda.memory_allocated("cuda:0") / 1024**2)      # ~136 MB
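A caveat about the timings above: CUDA kernels launch asynchronously, so wrapping a single call in time.time() mostly measures launch overhead. A minimal sketch of a more reliable measurement, reusing the tensors and standard_attention defined above:

import time
import torch

torch.cuda.synchronize()  # wait for pending work before starting the clock
start_time = time.time()
attn_output = standard_attention(query_states, key_states, value_states, attention_mask)
torch.cuda.synchronize()  # wait for the kernel to finish before stopping the clock
print(f'standard attention (synchronized) time: {(time.time() - start_time) * 1000:.3f} ms')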
FlashAttention-1 implementation:
import torch

torch.manual_seed(456)

N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))

# reference: standard PyTorch softmax attention
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat

# tiling block sizes; in practice these are derived from the SRAM size
Br = 4
Bc = d

# algorithm step 2: allocate the output O in HBM, initialized to zero
O = torch.zeros((N, d))
# algorithm step 2: l holds the running softmax denominators, stored in HBM
l = torch.zeros((N, 1))
# algorithm step 2: m holds the running row maxima, stored in HBM
m = torch.full((N, 1), -torch.inf)

# algorithm step 5: outer loop over key/value blocks
for block_start_Bc in range(0, N, Bc):
    block_end_Bc = block_start_Bc + Bc
    # algorithm step 6: load one block of Kj, Vj from HBM into SRAM
    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    # algorithm step 7: inner loop over query blocks
    for block_start_Br in range(0, N, Br):
        block_end_Br = block_start_Br + Br
        # algorithm step 8: load mi, li, Oi, Qi from HBM into SRAM
        mi = m[block_start_Br:block_end_Br, :]      # shape Br x 1
        li = l[block_start_Br:block_end_Br, :]      # shape Br x 1
        Oi = O[block_start_Br:block_end_Br, :]      # shape Br x d
        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
        # algorithm step 9
        Sij = Qi @ Kj.T  # shape Br x Bc
        # algorithm step 10: row-wise maximum of the current block
        mij_hat = torch.max(Sij, dim=1).values[:, None]
        # algorithm step 10: softmax numerator and per-block denominator
        pij_hat = torch.exp(Sij - mij_hat)
        lij_hat = torch.sum(pij_hat, dim=1)[:, None]
        # algorithm step 11: new running row maximum (current block vs. previous blocks)
        mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]
        # algorithm step 11: corrected softmax denominator (online-softmax update); this is the
        # same mathematical expression as the scalar online safe softmax, extended to row vectors
        li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
        # algorithm step 12: rescale the previous partial output and add the new block's contribution
        Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
        # algorithm step 13: write the running statistics back to HBM
        m[block_start_Br:block_end_Br, :] = mi_new  # row max
        l[block_start_Br:block_end_Br, :] = li_new  # softmax denominator
        # algorithm step 12: write Oi back to HBM
        O[block_start_Br:block_end_Br, :] = Oi

print(torch.allclose(O, expected_attention))
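The li_new update above is just the online (safe) softmax correction applied row-wise. A minimal scalar sketch of the same idea: processing the scores in chunks with a running (max, denominator) pair gives exactly the same softmax denominator as a single pass.

import torch

x = torch.tensor([1.0, 3.0, 2.0, 5.0])

# one-pass reference denominator
m_ref = x.max()
l_ref = torch.exp(x - m_ref).sum()

# two chunks with the running (m, l) correction
m, l = torch.tensor(-torch.inf), torch.tensor(0.0)
for chunk in x.split(2):
    m_chunk = chunk.max()
    l_chunk = torch.exp(chunk - m_chunk).sum()
    m_new = torch.maximum(m, m_chunk)
    # rescale the old partial sum to the new max, then add the new chunk's sum
    l = torch.exp(m - m_new) * l + torch.exp(m_chunk - m_new) * l_chunk
    m = m_new

print(torch.allclose(l, l_ref))  # True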
FlashAttention-2 implementation:
import torch

torch.manual_seed(456)

N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))

# reference: standard PyTorch softmax attention
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat

# tiling block sizes; in practice these are derived from the SRAM size
Br = 4
Bc = d

O = torch.zeros((N, d))

# algorithm step 3: outer loop over query blocks
for block_start_Br in range(0, N, Br):
    block_end_Br = block_start_Br + Br
    # algorithm step 4: load one block of Qi from HBM into SRAM
    Qi = Q_mat[block_start_Br:block_end_Br, :]
    # algorithm step 5: initialize the per-block accumulators
    Oi = torch.zeros((Br, d))              # shape Br x d
    li = torch.zeros((Br, 1))              # shape Br x 1
    mi = torch.full((Br, 1), -torch.inf)   # shape Br x 1
    # algorithm step 6: inner loop over key/value blocks
    for block_start_Bc in range(0, N, Bc):
        block_end_Bc = block_start_Bc + Bc
        # algorithm step 7: load Kj, Vj into SRAM
        Kj = K_mat[block_start_Bc:block_end_Bc, :]
        Vj = V_mat[block_start_Bc:block_end_Bc, :]
        # algorithm step 8
        Sij = Qi @ Kj.T
        # algorithm step 9: update the running row max and softmax denominator
        mi_new = torch.max(torch.column_stack([mi, torch.max(Sij, dim=1).values[:, None]]), dim=1).values[:, None]
        Pij_hat = torch.exp(Sij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(Pij_hat, dim=1)[:, None]
        # algorithm step 10: rescale the previous partial output and accumulate the new block
        Oi = Oi * torch.exp(mi - mi_new) + Pij_hat @ Vj
        mi = mi_new
    # step 12: normalize by the softmax denominator once per query block
    Oi = Oi / li
    # step 14: write Oi back to HBM
    O[block_start_Br:block_end_Br, :] = Oi

print(torch.allclose(O, expected_attention))
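The first two GPU snippets used a causal mask, so it is worth seeing where that mask fits into this blocked loop. The following is a minimal sketch, not part of the original algorithm listing: the block sizes, the skip condition, and the q_idx/k_idx helpers are illustrative choices. Key blocks that lie entirely in the future of the query block are skipped, and the rest are masked element-wise before the online-softmax update.

import torch

torch.manual_seed(456)

N, d = 16, 8
Br, Bc = 4, 8
Q_mat, K_mat, V_mat = torch.rand((N, d)), torch.rand((N, d)), torch.rand((N, d))

# reference: causally masked softmax attention
mask = torch.tril(torch.ones(N, N, dtype=torch.bool))
scores = (Q_mat @ K_mat.T).masked_fill(~mask, float('-inf'))
expected_attention = torch.softmax(scores, dim=1) @ V_mat

O = torch.zeros((N, d))
for qs in range(0, N, Br):
    Qi = Q_mat[qs:qs + Br, :]
    Oi = torch.zeros((Br, d))
    li = torch.zeros((Br, 1))
    mi = torch.full((Br, 1), -torch.inf)
    for ks in range(0, N, Bc):
        if ks > qs + Br - 1:  # key block is entirely in the future: skip it
            continue
        Kj, Vj = K_mat[ks:ks + Bc, :], V_mat[ks:ks + Bc, :]
        Sij = Qi @ Kj.T
        # apply the causal mask inside the block
        q_idx = torch.arange(qs, qs + Br)[:, None]
        k_idx = torch.arange(ks, ks + Bc)[None, :]
        Sij = Sij.masked_fill(k_idx > q_idx, float('-inf'))
        # same online-softmax update as in the unmasked FlashAttention-2 loop
        mi_new = torch.maximum(mi, Sij.max(dim=1, keepdim=True).values)
        Pij_hat = torch.exp(Sij - mi_new)
        li = torch.exp(mi - mi_new) * li + Pij_hat.sum(dim=1, keepdim=True)
        Oi = Oi * torch.exp(mi - mi_new) + Pij_hat @ Vj
        mi = mi_new
    O[qs:qs + Br, :] = Oi / li

print(torch.allclose(O, expected_attention))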
Comparing the scaled_dot_product_attention backends in PyTorch:
import torch
import torch.nn.functional as F
from rich import print
from torch.backends.cuda import sdp_kernel  # context manager for selecting the attention kernel
from enum import IntEnum
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"  # the fused kernels require a GPU

# hyperparameters
batch_size = 64
max_sequence_len = 256
num_heads = 32
embed_dimension = 32
dtype = torch.float16

# simulated q, k, v
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

# define a timer based on torch.utils.benchmark
def torch_timer(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# torch.backends.cuda also defines this enum; it is spelled out here to make backend_map easier to follow
class SDPBackend(IntEnum):
    r"""Enum class for the scaled dot product attention backends."""
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2

# the three explicit backends, selected via the sdp_kernel context manager and this dict
backend_map = {
    SDPBackend.MATH: {  # pure PyTorch (math) implementation
        "enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {  # FlashAttention kernel
        "enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {  # memory-efficient attention kernel
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

# baseline: let PyTorch pick the backend
print(f"default backend time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# default backend time: 558.831 microseconds

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"math time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# math time: 1013.422 microseconds

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"flash attention time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported")
# flash attention time: 557.343 microseconds

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"Memory efficient time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported")
# Memory efficient time: 428.007 microseconds
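Note that torch.backends.cuda.sdp_kernel has since been deprecated. A minimal sketch of the replacement context manager, assuming PyTorch 2.3 or newer and reusing the query/key/value tensors above:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # available in PyTorch >= 2.3 (assumption)

# force the FlashAttention backend for this call only
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)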