赞
踩
论文:https://arxiv.org/pdf/2205.14135.pdf
FlashAttention 提出的是为了解决 Transformers 在处理长序列时的速度慢和内存消耗大的问题。
这个问题主要是因为,自注意力模块在长序列上的时间和内存复杂度都是二次方的。
FlashAttention的本质是通过创新的算法设计,实现了对Transformer模型中注意力机制的高效计算。
FlashAttention通过减少HBM访问次数和避免存储大型中间矩阵,使BERT模型比MLPerf 1.1的速度记录快15%,GPT-2的训练速度提高了最高3倍。
使用FlashAttention的GPT-2模型,在4K的上下文长度下训练比Megatron在1K上下文长度下训练还快,同时困惑度(perplexity)更低,说明模型质量提高。
FlashAttention在常见序列长度(最高2K)上比标准注意力实现快3倍,并且其内存占用随序列长度线性增长,证明了其在效率和内存使用上的优势。
块稀疏FlashAttention通过仅计算重要的注意力块来减少计算量和内存使用,使得Transformer模型能够处理高达64K序列长度,且在Path-256任务上达到了63.1%的准确率,显示了其在处理长序列任务上的能力。
它通过以下核心方法和策略,解决了传统注意力计算在长序列处理时遇到的速度慢和内存消耗大的问题:
IO-感知优化:FlashAttention深入考虑了GPU内存层次之间的交互,特别是高带宽内存(HBM)与片上SRAM之间的读写操作,通过优化这些操作来减少内存访问成本,从而提高计算效率。
分块计算(Tiling):通过将输入序列分成小块并逐块处理,FlashAttention避免了一次性加载整个序列到内存中,减轻了内存压力,并使得注意力计算更加高效。
重计算策略:为了减少后向传播时对大型中间矩阵的存储需求,FlashAttention采用了在需要时重新计算这些矩阵的策略,从而节省了大量的内存空间。
核心融合:FlashAttention通过将多个计算步骤融合到一个CUDA核心中执行,减少了内存访问次数,并提高了执行速度。
这些策略共同作用,使FlashAttention能够以更少的内存访问和更低的时间复杂度,准确地计算出注意力,从而在保持模型质量的同时,显著提高了训练速度和效率。
此外,FlashAttention的设计还支持块稀疏注意力,进一步提高了处理长序列能力,使得在资源有限的情况下,Transformer模型能够处理更长的上下文信息,这在自然语言处理和其他需要长序列处理的领域中尤为重要。
FlashAttention本质上是对传统Transformer注意力机制的一个高效、内存友好的改进,它通过深入挖掘和优化计算机内存和计算资源的使用方式,推动了深度学习模型在复杂任务上的应用和发展。
在标准注意力计算中,每个操作(如 softmax、矩阵乘法等)都需要从 HBM 读取输入,计算后再将结果写回 HBM,导致高内存访问成本。
如果我们可以将多个操作合并为一个操作(核心融合),那么输入只需从 HBM 加载一次,这样就减少了内存访问次数,从而降低了内存访问成本。
在标准注意力机制中,整个注意力矩阵需要一次性计算并存储,导致对 HBM 的大量访问。
通过将输入矩阵 Q、K、V 分块并逐块计算,我们可以逐步生成注意力输出,减少了一次性对大量数据的访问需求。
一个大型矩阵乘法,通过将矩阵分为小块,每次只处理一部分数据,就可以减少内存的即时需求。
通过子解法的组合,FlashAttention 成功地解决了 Transformers 在处理长序列时速度慢和内存消耗大的问题。
FlashAttention 提出了一种计算精确注意力的算法,其关键在于通过减少对高带宽内存(HBM)的读写操作以及避免在后向传递中存储大型中间矩阵,从而实现了既节省内存又加速计算的目标。
在探索传统注意力机制在现代硬件(尤其是 GPU)上的执行效率时,遇到了一系列的具体问题,这些问题导致了处理速度慢和高内存消耗。
每种解决方案都直接针对了标准注意力实现中的效率瓶颈,通过改善内存访问模式、减少不必要的内存写入和读取、以及优化计算流程来提高整体性能。
左侧:展示了在GPU中的内存层次结构和FlashAttention如何在这种结构中工作。
它说明了:
右侧:显示了使用PyTorch实现的注意力计算与FlashAttention实现在GPT-2模型上的速度对比。
它说明了:
从提高开发效率、扩展IO-感知优化的应用范围,到优化多GPU并行计算的效率。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。