
A 200% speed boost! Flash Attention 2 rules them all, attention computation is no longer a problem!


❤️ Click above, star and pin this account to get good stuff delivered every day ❤️

Author | godweiyang

Published by | WeChat public account: 算法码上来 (ID: GodNLP)

- BEGIN -

Attention is the most important building block in the Transformer, but as the sequence length grows, its computational cost grows quadratically, and both memory and speed become problems. Many attention acceleration algorithms have therefore been proposed, such as flash attention, xformers, and so on.
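For reference (my own addition, just to make the quadratic scaling explicit), what all of these implementations compute is standard scaled dot-product attention:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V

The score matrix QK^T has shape (seq_len, seq_len), so a naive implementation costs O(n^2·d) time and O(n^2) memory in the sequence length n; the accelerated kernels below avoid materializing that full matrix.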

On July 17, flash attention 2 was open-sourced, with the authors claiming it is roughly 2x faster than version 1, so I couldn't wait to install it and see how big the improvement really is.
https://crfm.stanford.edu/2023/07/17/flash2.html

This test covers 4 implementations: a hand-written PyTorch attention, the _scaled_dot_product_attention operator from torch.nn.functional, the official flash attention 2 implementation, and the official xformers implementation.

To give the conclusion up front: in most cases, for both speed and memory, the ranking is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch implementation".

Test environment

  • A100-SXM4-80G, since flash attention 2 only supports A-series and H-series (Ampere/Hopper) GPUs.

  • PyTorch 1.13.1

  • CUDA 11.7

Installation commands

pip install ninja triton
# flash attention
pip install flash-attn --no-build-isolation
# xformers
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
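
As a quick sanity check (my own addition, not part of the original install steps), you can verify that both packages import and that you are on a supported GPU:

import torch
import flash_attn
import xformers

# Library versions: flash-attn should report a 2.x release for FlashAttention-2.
print(torch.__version__, torch.version.cuda)
print(flash_attn.__version__)
print(xformers.__version__)

# FlashAttention-2 currently needs an Ampere/Hopper GPU (e.g. A100/H100).
print(torch.cuda.get_device_name(0))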

Test code

import math
import random
import time

from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention, LowerTriangularMask

xformers_attn_bias = LowerTriangularMask()

# Naive PyTorch attention: materializes the full (seq_len, seq_len) score matrix.
def custom_attention(q, k, v, causal=False):
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask == 1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

# PyTorch's fused attention (private API in 1.13, returns (output, attn_weights)).
def pytorch_func(q, k, v, causal=False):
    o = F._scaled_dot_product_attention(q, k, v, is_causal=causal)[0]
    return o

# FlashAttention-2 kernel; expects (batch, seq_len, num_heads, head_dim) in fp16/bf16.
def flash_attention(q, k, v, causal=False):
    o = flash_attn_func(q, k, v, causal=causal)
    return o

# xformers memory-efficient attention; the causal mask is passed as an attention bias.
def xformers_attention(q, k, v, causal=False):
    attn_bias = xformers_attn_bias if causal else None
    o = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    return o

def test(func_name, q, k, v, *args, **kwargs):
    # The naive and built-in implementations expect (batch, num_heads, seq_len, head_dim).
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # Warm up for 5 iterations, then time a single call.
    for _ in range(5):
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20  # peak memory in MB
    torch.cuda.empty_cache()
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
    return o, tt, max_memory

if __name__ == "__main__":
    test_num = 10
    for idx in range(test_num):
        print(f"test {idx} >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        # Randomize batch size, sequence length, head count, head dim, dtype and causal flag.
        bsz = random.randint(1, 64)
        sql = random.randint(1, 4096)
        nh = random.choice([8, 12, 16])
        hd = random.choice([64, 128])
        dtype = random.choice([torch.float16, torch.bfloat16])
        causal = random.choice([False, True])
        print(f"shape: ({bsz}, {sql}, {nh}, {hd}), dtype: {dtype}, causal: {causal}")
        q = torch.randn((bsz, sql, nh, hd)).to("cuda:0", dtype)
        k = torch.rand_like(q)
        v = torch.rand_like(q)
        o, t, m = test("custom_attention", q, k, v, causal=causal)
        print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")
        pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=causal)
        print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
        assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)

        fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=causal)
        print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
        assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)
        xf_o, xf_t, xf_m = test("xformers_attention", q, k, v, causal=causal)
        print(f"xformers time: {xf_t:.6f}, speedup: {t/xf_t:.2f}; peak memory: {xf_m} MB, save: {int((m-xf_m)/m*100)}%")
        assert torch.allclose(o, xf_o, rtol=1e-2, atol=1e-2)

Test results

I tested 10 random input shapes (batch_size, seq_len, num_head, head_dim), with the causal mask randomly turned on or off and the dtype randomly chosen between fp16 and bf16. The results are as follows:

test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000734, peak memory: 105 MB
pytorch func time: 0.000104, speedup: 7.06; peak memory: 49 MB, save: 53%
flash attention time: 0.000055, speedup: 13.45; peak memory: 43 MB, save: 59%
xformers time: 0.000152, speedup: 4.82; peak memory: 61 MB, save: 41%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (57, 1235, 8, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.015195, peak memory: 3093 MB
pytorch func time: 0.001930, speedup: 7.87; peak memory: 571 MB, save: 81%
flash attention time: 0.000635, speedup: 23.94; peak memory: 496 MB, save: 83%
xformers time: 0.001383, speedup: 10.99; peak memory: 696 MB, save: 77%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (61, 2045, 16, 128), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.101898, peak memory: 18782 MB
pytorch func time: 0.031511, speedup: 3.23; peak memory: 4115 MB, save: 78%
flash attention time: 0.005292, speedup: 19.25; peak memory: 3560 MB, save: 81%
xformers time: 0.009730, speedup: 10.47; peak memory: 3972 MB, save: 78%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (15, 1526, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.010720, peak memory: 3756 MB
pytorch func time: 0.001101, speedup: 9.74; peak memory: 1732 MB, save: 53%
flash attention time: 0.000380, speedup: 28.24; peak memory: 1211 MB, save: 67%
xformers time: 0.000862, speedup: 12.43; peak memory: 824 MB, save: 78%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (28, 3227, 12, 128), dtype: torch.float16, causal: True
custom pytorch time: 0.091987, peak memory: 15090 MB
pytorch func time: 0.029867, speedup: 3.08; peak memory: 2223 MB, save: 85%
flash attention time: 0.004636, speedup: 19.84; peak memory: 1924 MB, save: 87%
xformers time: 0.008405, speedup: 10.94; peak memory: 2151 MB, save: 85%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 2047, 8, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.026797, peak memory: 6242 MB
pytorch func time: 0.003424, speedup: 7.83; peak memory: 1388 MB, save: 77%
flash attention time: 0.000947, speedup: 28.29; peak memory: 1049 MB, save: 83%
xformers time: 0.002072, speedup: 12.93; peak memory: 1006 MB, save: 83%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (24, 2637, 16, 128), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.053066, peak memory: 11970 MB
pytorch func time: 0.047200, speedup: 1.12; peak memory: 2205 MB, save: 81%
flash attention time: 0.006308, speedup: 8.41; peak memory: 1885 MB, save: 84%
xformers time: 0.011971, speedup: 4.43; peak memory: 2055 MB, save: 82%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (37, 3214, 12, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.097363, peak memory: 19552 MB
pytorch func time: 0.012316, speedup: 7.91; peak memory: 2142 MB, save: 89%
flash attention time: 0.003399, speedup: 28.65; peak memory: 1720 MB, save: 91%
xformers time: 0.007016, speedup: 13.88; peak memory: 1995 MB, save: 89%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (40, 2126, 16, 64), dtype: torch.float16, causal: True
custom pytorch time: 0.066542, peak memory: 12737 MB
pytorch func time: 0.006069, speedup: 10.96; peak memory: 1856 MB, save: 85%
flash attention time: 0.002226, speedup: 29.89; peak memory: 1516 MB, save: 88%
xformers time: 0.004234, speedup: 15.71; peak memory: 1840 MB, save: 85%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (47, 3355, 12, 64), dtype: torch.bfloat16, causal: False
custom pytorch time: 0.100385, peak memory: 26267 MB
pytorch func time: 0.024839, speedup: 4.04; peak memory: 2353 MB, save: 91%
flash attention time: 0.008755, speedup: 11.47; peak memory: 1956 MB, save: 92%
xformers time: 0.016346, speedup: 6.14; peak memory: 2483 MB, save: 90%

As you can see, in most cases, for both speed and memory, the ranking is "flash attention 2 > xformers > PyTorch function > hand-written PyTorch implementation".

Moreover, all of these methods have very friendly APIs and can basically be dropped straight into the attention module of your own model. However, flash attention 2 does not seem to support passing in an arbitrary attention mask; you can only enable a causal mask, which is somewhat limiting, though it is good enough for GPT-style models. A rough sketch of what such a drop-in replacement might look like is shown below.
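This is a hypothetical module of my own, not from the flash-attn repo; names like FlashSelfAttention and the fused qkv projection are assumptions. The core change is just routing q/k/v through flash_attn_func:

import torch
import torch.nn as nn
from flash_attn import flash_attn_func

class FlashSelfAttention(nn.Module):
    """Hypothetical self-attention module backed by FlashAttention-2."""
    def __init__(self, hidden_size, num_heads, causal=True):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.causal = causal
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # x: (batch, seq_len, hidden); must be fp16/bf16 on a supported GPU.
        b, s, _ = x.shape
        qkv = self.qkv(x).view(b, s, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each (batch, seq_len, num_heads, head_dim)
        o = flash_attn_func(q, k, v, causal=self.causal)  # causal-only masking
        return self.out(o.reshape(b, s, -1))

Note that there is no argument for an arbitrary attention mask here, which is exactly the limitation mentioned above.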

- END -

I am godweiyang, an NLP algorithm engineer at ByteDance AI Lab. I ranked first in my major for both my bachelor's and master's degrees in computer science at East China Normal University (华师), and I specialize in algorithm and model optimization and machine translation.

Reply 【算法】

to get solutions to the 100+ algorithm problems I have written up for interviews; work through them and getting into a big tech company will be no problem.

Reply 【CUDA】

to get the beginner-oriented CUDA tutorial series I prepared.

Reply 【内推】

for a ByteDance referral with a high pass rate; add me on WeChat to check on progress or ask questions anytime.

Reply 【加群】

to join my technical discussion (chat) group and referral group, where ByteDance HR is on hand to answer questions.


Please, folks, tap "在看" (Looking); today's read count is counting on you.
