Author | godweiyang
From | WeChat official account: 算法码上来 (ID: GodNLP)
- BEGIN -
Attention is the most important building block in the Transformer, but as the sequence length grows, its compute and memory cost grow quadratically, and both GPU memory and speed quickly become bottlenecks. Many attention acceleration algorithms have therefore been proposed, such as FlashAttention, xformers, and so on.
Then on July 17, FlashAttention-2 was open-sourced, with the authors claiming it is roughly 2x faster than the first version, so I couldn't wait to install it and see how big the improvement really is.
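To get a feel for how fast the quadratic term blows up, here is a rough back-of-the-envelope estimate of my own (the shape is borrowed from test 2 below, purely as an illustration): a naive implementation materializes a seq_len x seq_len score matrix per head.

# Memory needed just to materialize the attention score matrix
# in a naive implementation (shape taken from test 2 below).
bsz, seq_len, num_head = 61, 2045, 16
bytes_per_elem = 2  # bf16
score_bytes = bsz * num_head * seq_len * seq_len * bytes_per_elem
print(f"{score_bytes / 2**30:.1f} GiB")  # ~7.6 GiB for the scores alone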
https://crfm.stanford.edu/2023/07/17/flash2.html
There are four implementations under test: a hand-written PyTorch attention, the _scaled_dot_product_attention operator provided by torch.nn.functional, the official FlashAttention-2 implementation, and the official xformers implementation.
The conclusion first: in most cases, both speed and memory usage rank as "FlashAttention-2 > xformers > PyTorch function > hand-written PyTorch".
The test environment is an A100-SXM4-80GB GPU (chosen because FlashAttention-2 only supports A- and H-series cards), with:
PyTorch 1.13.1
CUDA 11.7
Install the dependencies and the two libraries:

pip install ninja triton

# flash attention
pip install flash-attn --no-build-isolation

# xformers
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
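As a quick sanity check (my own addition, not part of the original setup), you can confirm that both libraries import correctly and that the GPU meets FlashAttention-2's Ampere/Hopper requirement, i.e. compute capability 8.0 or above:

import torch
import flash_attn
import xformers

print(torch.__version__, torch.version.cuda)
print(flash_attn.__version__, xformers.__version__)
# FlashAttention-2 needs an A- or H-series GPU (compute capability >= 8.0)
print(torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))
assert torch.cuda.get_device_capability(0)[0] >= 8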
The full benchmark script:

import math
import random
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention, LowerTriangularMask


xformers_attn_bias = LowerTriangularMask()

def custom_attention(q, k, v, causal=False):
    # q, k, v: (batch, num_head, seq_len, head_dim)
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        # additive causal mask: 0 on and below the diagonal, -inf above it
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask == 1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

def pytorch_func(q, k, v, causal=False):
    # private API in PyTorch 1.13; returns a tuple, output is the first element
    o = F._scaled_dot_product_attention(q, k, v, is_causal=causal)[0]
    return o

def flash_attention(q, k, v, causal=False):
    # q, k, v: (batch, seq_len, num_head, head_dim)
    o = flash_attn_func(q, k, v, causal=causal)
    return o

def xformers_attention(q, k, v, causal=False):
    # q, k, v: (batch, seq_len, num_head, head_dim)
    attn_bias = xformers_attn_bias if causal else None
    o = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    return o

def test(func_name, q, k, v, *args, **kwargs):
    # the two PyTorch baselines expect (batch, num_head, seq_len, head_dim)
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # warm up for 5 iterations, then time a single call
    for _ in range(5):
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20
    torch.cuda.empty_cache()

    # rearrange back to (batch, seq_len, num_head, head_dim) for comparison
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")

    return o, tt, max_memory

if __name__ == "__main__":
    test_num = 10
    for idx in range(test_num):
        print(f"test {idx} >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        bsz = random.randint(1, 64)
        sql = random.randint(1, 4096)
        nh = random.choice([8, 12, 16])
        hd = random.choice([64, 128])
        dtype = random.choice([torch.float16, torch.bfloat16])
        causal = random.choice([False, True])
        print(f"shape: ({bsz}, {sql}, {nh}, {hd}), dtype: {dtype}, causal: {causal}")
        q = torch.randn((bsz, sql, nh, hd)).to("cuda:0", dtype)
        k = torch.rand_like(q)
        v = torch.rand_like(q)

        o, t, m = test("custom_attention", q, k, v, causal=causal)
        print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")

        pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=causal)
        print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
        assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)

        fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
        print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
        assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)

        xf_o, xf_t, xf_m = test("xformers_attention", q, k, v, causal=causal)
        print(f"xformers time: {xf_t:.6f}, speedup: {t/xf_t:.2f}; peak memory: {xf_m} MB, save: {int((m-xf_m)/m*100)}%")
        assert torch.allclose(o, xf_o, rtol=1e-2, atol=1e-2)
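One note on the PyTorch baseline: _scaled_dot_product_attention with the leading underscore is the private API in PyTorch 1.13 and returns a tuple, hence the [0] indexing. On PyTorch 2.0 and later the operator is public and returns the output tensor directly, so the wrapper would presumably look like this instead (not benchmarked here, since these tests ran on 1.13.1):

def pytorch_func_v2(q, k, v, causal=False):
    # PyTorch >= 2.0 public API; dispatches to fused kernels when available
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)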
I tested 10 groups of randomly chosen input shapes (batch_size, seq_len, num_head, head_dim), with the causal mask randomly toggled and the dtype randomly picked from fp16 and bf16. The results are as follows:
test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000734, peak memory: 105 MB
pytorch func time: 0.000104, speedup: 7.06; peak memory: 49 MB, save: 53%
flash attention time: 0.000055, speedup: 13.45; peak memory: 43 MB, save: 59%
xformers time: 0.000152, speedup: 4.82; peak memory: 61 MB, save: 41%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (57, 1235, 8, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.015195, peak memory: 3093 MB
pytorch func time: 0.001930, speedup: 7.87; peak memory: 571 MB, save: 81%
flash attention time: 0.000635, speedup: 23.94; peak memory: 496 MB, save: 83%
xformers time: 0.001383, speedup: 10.99; peak memory: 696 MB, save: 77%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (61, 2045, 16, 128), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.101898, peak memory: 18782 MB
pytorch func time: 0.031511, speedup: 3.23; peak memory: 4115 MB, save: 78%
flash attention time: 0.005292, speedup: 19.25; peak memory: 3560 MB, save: 81%
xformers time: 0.009730, speedup: 10.47; peak memory: 3972 MB, save: 78%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (15, 1526, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.010720, peak memory: 3756 MB
pytorch func time: 0.001101, speedup: 9.74; peak memory: 1732 MB, save: 53%
flash attention time: 0.000380, speedup: 28.24; peak memory: 1211 MB, save: 67%
xformers time: 0.000862, speedup: 12.43; peak memory: 824 MB, save: 78%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (28, 3227, 12, 128), dtype: torch.float16, causal: True
custom pytorch time: 0.091987, peak memory: 15090 MB
pytorch func time: 0.029867, speedup: 3.08; peak memory: 2223 MB, save: 85%
flash attention time: 0.004636, speedup: 19.84; peak memory: 1924 MB, save: 87%
xformers time: 0.008405, speedup: 10.94; peak memory: 2151 MB, save: 85%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 2047, 8, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.026797, peak memory: 6242 MB
pytorch func time: 0.003424, speedup: 7.83; peak memory: 1388 MB, save: 77%
flash attention time: 0.000947, speedup: 28.29; peak memory: 1049 MB, save: 83%
xformers time: 0.002072, speedup: 12.93; peak memory: 1006 MB, save: 83%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (24, 2637, 16, 128), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.053066, peak memory: 11970 MB
pytorch func time: 0.047200, speedup: 1.12; peak memory: 2205 MB, save: 81%
flash attention time: 0.006308, speedup: 8.41; peak memory: 1885 MB, save: 84%
xformers time: 0.011971, speedup: 4.43; peak memory: 2055 MB, save: 82%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 3214, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.097363, peak memory: 19552 MB
pytorch func time: 0.012316, speedup: 7.91; peak memory: 2142 MB, save: 89%
flash attention time: 0.003399, speedup: 28.65; peak memory: 1720 MB, save: 91%
xformers time: 0.007016, speedup: 13.88; peak memory: 1995 MB, save: 89%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (40, 2126, 16, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.066542, peak memory: 12737 MB
pytorch func time: 0.006069, speedup: 10.96; peak memory: 1856 MB, save: 85%
flash attention time: 0.002226, speedup: 29.89; peak memory: 1516 MB, save: 88%
xformers time: 0.004234, speedup: 15.71; peak memory: 1840 MB, save: 85%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (47, 3355, 12, 64), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.100385, peak memory: 26267 MB
pytorch func time: 0.024839, speedup: 4.04; peak memory: 2353 MB, save: 91%
flash attention time: 0.008755, speedup: 11.47; peak memory: 1956 MB, save: 92%
xformers time: 0.016346, speedup: 6.14; peak memory: 2483 MB, save: 90%
As you can see, in most cases both speed and memory usage follow "FlashAttention-2 > xformers > PyTorch function > hand-written PyTorch".
All of these APIs are also very easy to use and can basically be dropped straight into the attention module of your own model. One caveat: FlashAttention-2 does not seem to support passing in an arbitrary attention mask; you can only toggle a causal mask, which is somewhat limiting, though it is perfectly adequate for GPT-style models.
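If you do need an arbitrary (non-causal) mask, xformers' memory_efficient_attention accepts an additive bias tensor, so one workaround is to encode the mask as 0 / -inf values yourself. A minimal sketch of my own, reusing the q, k, v, bsz, nh, sql names from the benchmark above; keep_mask is a hypothetical padding mask, and the kernels place alignment constraints on the bias tensor, so treat this as a starting point rather than a drop-in solution:

# keep_mask: hypothetical (batch, seq_len) bool tensor, True = token may be attended to
# q, k, v: (batch, seq_len, num_head, head_dim), as in the benchmark above
bias = torch.zeros(bsz, nh, sql, sql, device=q.device, dtype=q.dtype)
bias.masked_fill_(~keep_mask[:, None, None, :], torch.finfo(q.dtype).min)
o = memory_efficient_attention(q, k, v, attn_bias=bias)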
- END -
I'm godweiyang, an NLP algorithm engineer at ByteDance AI Lab. I graduated first in my major for both my bachelor's and master's degrees in computer science at East China Normal University, and I work on algorithms, model optimization, and machine translation.