赞
踩
FlashAttention-2提出后,便得到了大量关注。本文将具体讲述FlashAttention-2的前世今生,包括FlashAttention1&2的原理解析、加速效果比较以及面向AIGC的加速实践,在这里将相关内容与大家分享~
引言
将 Transformers 扩展到更长的序列长度一直是过去几年的一个热点问题,这将有助于提高语言建模和高分辨率图像理解的能力,也有利于音频和视频生成方面的新应用场景研发。Attention层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存占用是序列长度的二次方。使用近似计算的Attention方法,可以通过减少FLOP计算次数、甚至于牺牲模型质量来降低计算复杂性,但通常无法实现大比例的加速。
由斯坦福大学提出的FlashAttention方法,让使用更长sequence计算Attention成为可能,并且通过线性级别的增长来节省内存以及加速计算。因为FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。
FlashAttention
FlashAttention主要关注IO-aware,进一步优化GPU显存的读写效率。这是一种 IO 感知的精确Attention算法,它使用tiling(这里可以理解为分块)来减少 GPU 高带宽内存 (HBM) 和 GPU 片上 SRAM 之间的内存读/写次数。这里的HBM可以理解为显存,SRAM可以理解为cache。通过测试IO复杂性,相比标准 Attention,FlashAttention需要更少的 HBM 访问,并且对于不同的SRAM 大小来说都是有效的。除此以外,FlashAttention还可以扩展到block-sparse attention,产生比任何现有近似注意力方法更快的近似注意力算法。
FlashAttention与 MLPerf 1.1 训练速度相比,对于BERT-large(序列长度 512)实现端到端wall-clock加速15%,对于GPT-2(序列长度 1K)加速 3 倍。FlashAttention 和block-sparse FlashAttention 可在 Transformers 中实现更长的上下文,从而产生更高质量的模型,GPT-2 上的困惑度提升0.7,长文档分类的test结果提高 6.4 个点。
背景知识:
上图的左图,表示存储结构,可以简单理解为:SRAM表示缓存,HBM表示显存,DRAM表示内存。
在不访问整个输入的情况下优化attention计算,并减少相关计算量。重构attention计算,将输入分割成块,并对分块进行多次传递,从而逐步执行attention计算(该步骤称为tiling)。
如上图所示,FlashAttention 使用tiling来防止在相对较慢的 GPU显存上实现大型 声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。