赞
踩
FlashAttention一般指的是FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness这篇,当然Transformer Quality in Linear Time这篇里非要说FLASH = Fast Linear Attention with a Single Head,命名有点无语,关于FLASH的细节参考 FLASH:可能是近来最有意思的高效Transformer设计 ,下面重点写写FlashAttention:
tiling中文是瓦片化,实际上就是把计算像瓦片一样铺向SRAM,保证运算不要频繁在SRAM和**HBM(High-Bandwidth Memory,HBM是高带宽内存,也就是我们常说的显存)**频繁切换,提高速度。
对于标准注意力实现,初期我们需要把输入
Q
,
K
,
V
\mathbf{Q}, \mathbf{K}, \mathbf{V}
Q,K,V从HBM中读取,并计算完毕后把输出
O
\mathbf{O}
O写入到HBM中。
第一步把
Q
,
K
\mathbf{Q}, \mathbf{K}
Q,K读取出来计算出
S
=
Q
K
⊤
\mathbf{S}=\mathbf{Q K}^{\top}
S=QK⊤,然后把
S
\mathbf{S}
S存回去,内存访问复杂度
Θ
(
N
d
+
N
2
)
\Theta\left(N d+N^2\right)
Θ(Nd+N2)。
第二步把 S \mathbf{S} S读取出来计算出 P = softmax ( S ) \mathbf{P}=\operatorname{softmax}(\mathbf{S}) P=softmax(S),然后把 P \mathbf{P} P存回去,内存访问复杂度 Θ ( N 2 ) \Theta\left(N^2\right) Θ(N2)。
第三步把 V , P \mathbf{V}, \mathbf{P} V,P读取出来计算出 O = P V \mathbf{O}=\mathbf{P} \mathbf{V} O=PV,然后计算出结果 O \mathbf{O} O,内存访问复杂度 Θ ( N d + N 2 ) \Theta\left(N d+N^2\right) Θ(Nd+N2)。
综上所述,整体的内存访问复杂度为 Θ ( N d + N 2 ) \Theta\left(N d+N^2\right) Θ(Nd+N2)。
FlashAttention关键的想法就是tile(分块),把QKV都拆成块。这里一个关键点是softmax怎么算,有点绕,简单说就是把每部分分子分母的和给存下来,归一化到相同的比例。下面是个具体的例子,
l
_
p
r
e
l\_pre
l_pre是分母缩最大倍数后的和,也是最绕的点。假设QK结果是[1,2],那么softmax结果就是
[
e
1
e
1
+
e
2
,
e
2
e
1
+
e
2
]
[\frac{e^1}{e^1+e^2},\frac{e^2}{e^1+e^2}]
[e1+e2e1,e1+e2e2]
再乘以V的结果就是:
e
1
∗
v
1
e
1
+
e
2
+
e
2
∗
v
2
e
1
+
e
2
\frac{e^1*v_1}{e^1+e^2}+\frac{e^2*v_2}{e^1+e^2}
e1+e2e1∗v1+e1+e2e2∗v2
如果拆成两步算,第一步:
c
u
r
_
s
u
m
=
e
1
∗
v
1
e
1
m
_
p
r
e
=
m
a
x
(
e
1
)
=
e
1
,
是分子
e
的和
l
_
p
r
e
=
s
u
m
(
e
1
)
=
e
1
,
是分母
e
的和
cur\_sum = \frac{e^1*v_1}{e^1} \\ m\_pre = max(e^1)=e^1,是分子e的和 \\ l\_pre = sum(e^1)=e^1,是分母e的和
cur_sum=e1e1∗v1m_pre=max(e1)=e1,是分子e的和l_pre=sum(e1)=e1,是分母e的和
第二步:
m
_
c
u
r
=
m
a
x
(
e
2
,
m
_
p
r
e
)
=
e
2
l
_
p
r
e
∗
=
e
m
_
p
r
e
−
m
_
c
u
r
=
e
1
−
2
,分母缩共同倍数后相加
l
_
c
u
r
=
s
u
m
(
e
2
−
2
)
+
l
_
p
r
e
c
u
r
_
s
u
m
=
c
u
r
_
s
u
m
∗
l
_
p
r
e
l
_
c
u
r
=
e
1
∗
v
1
e
1
∗
e
−
1
e
−
1
+
e
0
c
u
r
_
s
u
m
+
=
v
2
∗
c
u
r
_
s
u
m
l
_
p
r
e
=
e
1
∗
v
1
e
1
+
e
2
+
e
2
∗
v
2
e
1
+
e
2
m\_cur = max(e^2,m\_pre)=e^2 \\ l\_pre *= e^{m\_pre - m\_cur}=e^{1-2} ,分母缩共同倍数后相加\\ l\_cur = sum(e^{2-2})+l\_pre\\ cur\_sum=cur\_sum*\frac{l\_pre}{l\_cur}=\frac{e^1*v_1}{e^1}*\frac{e^{-1}}{e^{-1}+e^0}\\ cur\_sum+=\frac{v_2*cur\_sum}{l\_pre}=\frac{e^1*v_1}{e^1+e^2}+\frac{e^2*v_2}{e^1+e^2}
m_cur=max(e2,m_pre)=e2l_pre∗=em_pre−m_cur=e1−2,分母缩共同倍数后相加l_cur=sum(e2−2)+l_precur_sum=cur_sum∗l_curl_pre=e1e1∗v1∗e−1+e0e−1cur_sum+=l_prev2∗cur_sum=e1+e2e1∗v1+e1+e2e2∗v2
这样,在前向的过程中,我们采用分块计算的方式,避免了矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗。详细的复杂度分析可以参考原文和https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh
我们这里则采用重新计算的方式来计算对应的梯度。在上面前向计算的时候我们不会存储
S
,
P
\mathbf{S}, \mathbf{P}
S,P矩阵,但是我们会存储对应的指数项之和
L
L
L来进行梯度的计算。这里不展开写了,细节可以参考原文和https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh
目前,Flash Attention已经集成至torch2.0,并且社区也提供了多种实现
源自vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention这篇paper,关键的技术有两点PagedAttention和内存共享:
KV Cache是大模型推理优化的一个常用技术,该技术以空间换时间的思想,通过使用上次推理的KV缓存,可以在不影响任何计算精度的前提下,提高推理性能,降低端到端的时延。
以GPT为代表的Decoder-Only自回归语言模型在生成每一个新的 token 时,接受所有之前生成的 tokens 作为输入。然而,对于这些先前生成的 tokens,每次生成新的 token 时都需要重新计算他们的表示,这个过程造成了大量的计算浪费。KV Cache 的引入就是为了解决这个问题。
KV Cache实质上是存储了之前计算过的 key-value 对用于下一个Token的生成。在 Transformer 结构中,self-attention 中的k_proj, v_proj会将输入的每个 token 转化为一个 key 和一个 value,然后使用这些 key-value 以及当前的query对来计算下一个 token。引入 KV Cache,我们就可以将之前生成的 tokens 对应的 key-value 对存储起来,当生成新的 token 时,直接从 KV Cache 中取出这些已经计算好的 key-value 对,再把当前token的key-value做一个连结在进行计算,这样就避免了KV的重复计算,大大提高了计算效率。
到huggingface代码里看,例如Hugging Face的transformers库代码实现就比较清爽,在modeling_gpt2.py中Attention部分相关代码如下:
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None: # 当输出第一个token后,layer_past就是非None了
past_key, past_value = layer_past # 取出之前计算好的 key, value
key = torch.cat((past_key, key), dim=-2) # past_key 与当前 token 对应的 key 拼接
value = torch.cat((past_value, value), dim=-2) # past_value 与当前 token 对应的 value 拼接
if use_cache is True:
present = (key, value)
else:
present = None
整体来说,使用KV Cache包含以下两个步骤:
beamsearch、topk sampling、nucleus sampling等解码策略在hugging face的GenerationMixin(transformers/generation/utils.py)中均有所实现,在hugging face上的生成式模型都要继承GenerationMixin,以beamsearch为例,下面self就是继承的子类提供的根据 w 0.. i − 1 w_{0..i-1} w0..i−1给 w i w_{i} wi打分的language model,这个language model里当然要实现例如kv_cache等策略:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
通过KV Cache的技术,我们已经可以极大地提升LLM地推理速度,但是现有的Cache仍存在一些问题,
到vLLM开源的代码里去看,实际上PagedAttention是以加了一个cuda的核函数实现的,核心是single_query_cached_kv_attention_kernel函数的工作(https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu)
由于块在内存中不需要连续存储,我们可以像操作系统的虚拟内存那样以更加灵活的方式管理键和值的缓存:可以将块看作页,标记看作字节,序列看作进程。序列的连续逻辑块通过块表映射到非连续的物理块。随着生成新的标记,序列的边长,物理块按需进行分配。
在PagedAttention中,内存浪费仅发生在序列的最后一个块中。这样就使得我们的方案接近最优的内存使用率,仅有不到4%的浪费。通过内存效率的提升,我们能够显著提升BatchSize,同时进行多个序列的推理,提高GPU利用率,从而显著提高吞吐量。
PagedAttention:Cache在物理上不必连续
使用 PagedAttention 的请求的示例生成过程
在并行采样中,从相同的提示生成多个输出序列。在这种情况下,可以在输出序列之间共享提示的计算和内存。通过其块表,PagedAttention能够自然地实现内存共享。类似于进程共享物理页,PagedAttention中的不同序列可以通过将它们的逻辑块映射到相同的物理块来共享块。为确保安全共享,PagedAttention跟踪物理块的引用计数并实现 Copy-on-Write 机制。
通过PagedAttention的内存共享机制,极大地降低了复杂采样算法(如ParallelSampling和BeamSearch)的内存开销,使其内存使用量下降了高达55%。这项优化可以直接带来最多2.2倍的吞吐量提升,从而使得LLM服务中使用这些采样方法变得更加实用。
同时进行多输出的采样
多输出采样的物理展示
LightSeq支持BERT、GPT、Transformer、VAE 等众多模型,同时支持beam search、diverse beam search[5]、sampling等多种解码方式。下表详细列举了Faster Transformer[7]、Turbo Transformers[6]和LightSeq三种推理引擎在文本生成场景的功能差异:
这个工作大约是20年做的,部分转载自https://mp.weixin.qq.com/s/HUSYSrjG65p1TU9lS_KEUA,几个关键技术
蓝色部分是自定义核函数,黄色部分是矩阵乘法。可以发现,矩阵乘法之间的运算全部都用一个定制化核函数实现了,因此大大减少了核函数调用和显存读写,最终提升了运算速度。
此外LightSeq还实现了词嵌入层和损失函数层的算子融合。对于词嵌入层,LightSeq将词表查找与放缩、位置向量融合以及dropout操作都写成了一个核函数。对于损失函数层,将交叉熵损失融合成一个核函数。通过输入输出层的融合,进一步减小了模型训练的时间,增加了显卡利用率。
在融合之前一个词嵌入层需要经过词向量查找与放缩、位置向量查找、两者相加、dropout五种运算,因此需要频繁调用核函数,非常耗时。而将这五个操作融合成一个核函数可以大大加快获取最终词表示的速度
为了避免计算过程中的显存申请释放并节省显存占用,LightSeq 首先对模型中所有动态的 shape 都定义了最大值(例如最大序列长度),将所有动态shape转换为静态。接着在服务启动的时候,为计算过程中的每个中间计算结果按最大值分配显存,并对没有依赖的中间结果共用显存。这样对每个请求,模型推理时不再申请显存,做到了:不同请求的相同 Tensor 复用显存;同请求的不同 Tensor 按 shape 及依赖关系复用显存。
通过该显存复用策略,在一张 T4 显卡上,LightSeq 可以同时部署多达 8 个 Transformer big 模型(batch_size=8,最大序列长度=8,beam_size=4,vocab_size=3万)。从而在低频或错峰等场景下,大大提升显卡利用率。
在自回归序列生成场景中,最复杂且耗时的部分就是解码。LightSeq 目前已经支持了 beam search、diversity beam search、top-k/top-p sampling 等多种解码方法,并且可以配合 Transformer、GPT使用,达到数倍加速。这里我们以应用最多的 beam search 为例,介绍一下 LightSeq 对解码过程的优化。
首先来看下在深度学习框架中传统是如何进行一步解码计算的:
1.计算以每个token为结尾的序列的log probability
log_token_prob = tf.nn.log_softmax(logit) # [batch_size, beam_size, vocab_size]
log_seq_prob += log_token_prob # [batch_size, beam_size, vocab_size]
log_seq_prob = tf.reshape(log_seq_prob, [-1, beam_size * vocab_size])
topk_log_probs, topk_indices = tf.nn.top_k(log_seq_prob, k=K)
refresh_cache(cache, topk_indices)
可以发现,为了挑选概率 top-k 的 token ,必须在 [batch_size, beam_size, vocab_size]大小的 logit 矩阵上进行 softmax 计算及显存读写,然后进行 batch_size 次排序。通常 vocab_size 都是在几万规模,因此计算量非常庞大,而且这仅仅只是一步解码的计算消耗。因此实践中也可以发现,解码模块在自回归序列生成任务中,累计延迟占比很高(超过 30%)。
LightSeq 的创新点在于结合 GPU 计算特性,借鉴搜索推荐中常用的粗选-精排的两段式策略,将解码计算改写成层级式,设计了一个 logit 粗选核函数,成功避免了 softmax 的计算及对十几万元素的排序。该粗选核函数遍历 logit 矩阵两次:
• 第一次遍历,对每个 beam,将其 logit 值随机分成k组,每组求最大值,然后对这k个最大值求一个最小值,作为一个近似的top-k值(一定小于等于真实top-k值),记为R-top-k。在遍历过程中,同时可以计算该beam中logit的log_sum_exp值。
• 第二次遍历,对每个 beam,找出所有大于等于 R-top-k 的 logit 值,将(logit - log_sum_exp + batch_id * offset, beam_id * vocab_size + vocab_id)写入候选队列,其中 offset 是 logit 的下界。
在第一次遍历中,logit 值通常服从正态分布,因此算出的R-top-k值非常接近真实top-k值。同时因为这一步只涉及到寄存器的读写,且算法复杂度低,因此可以快速执行完成(十几个指令周期)。实际观察发现,在top-4设置下,根据R-top-k只会从几万token中粗选出十几个候选,因此非常高效。第二次遍历中,根据R-top-k粗选出候选,同时对 logit 值按 batch_id 做了值偏移,多线程并发写入显存中的候选队列。
粗选完成后,在候选队列中进行一次排序,就能得到整个batch中每个序列的准确top-k值,然后更新缓存,一步解码过程就快速执行完成了。
下面是k=2,词表大小=8的情况下一个具体的示例(列代表第几个字符输出,行代表每个位置的候选)。可以看出,原来需要对 16 个元素进行排序,而采用层级解码之后,最后只需要对 5 个元素排序即可,大大降低了排序的复杂度。
几个关键技术:
这个算法源自字节跳动 AML 团队之前的工作 “effective Transformer”,在 NVIDIA 开源 FasterTransformer 中也有集成。ByteTransformer 同样使用该算法去除对 attention 外矩阵乘的额外计算。
算法步骤:
为了优化 attention 部分的性能,ByteTransformer 中实现了 fused multi-head attention 算子。对于 seqlen 长度,以 384 为界划分为两种实现方式:
NVIDIA 开发的 grouped GEMM 可以在一个 kernel 中完成多个独立矩阵乘问题的计算,利用这个性质可以实现 Attention 中的 padding free。
grouped GEMM 原理:kernel 中每个 threadblock (CTA) 固定 tiling size,每个矩阵乘子问题根据 problem size 和 tiling size,拆解为不同数量的待计算块,再把这些块平均分配到每个 threadblock 中进行计算。
使用 grouped GEMM 实现 attention 时,由于子问题的数量 batch_size x head_num 通常较大,读取子问题参数会有不小的开销,因为从线程角度看,每个线程都需要遍历读取所有的子问题大小。
为了解决这个问题,ByteTransformer 对 grouped GEMM 中读取子问题参数进行了性能优化,使其可以忽略不计:
这部分和FlashAttention其实有点像,但
FlashAttention是用一个threadblock循环计算不同的块,然后不断更新结果。ByteTransformer是
多个threadblock独立计算结果后,再用一个小kernel聚合。FlashAttention是在一个fused kernel中完成所有计算的,而ByteTransformer是拆成了两个矩阵乘的kernel。在batchsizexhead_num小的情况下,flash会因为threadblock并发度不足打满硬件sm数量,性能较差,但除此之外因为fuse的更彻底,flash性能会更好一点
为了进一步提高性能,把 Q x K 之后的 softmax 也 fuse 到矩阵乘算子中,相比单独的 softmax kernel 节省了中间矩阵的访存操作。
因为 softmax 需要对整行数据做归约,但因为共享内存大小的限制,一个 threadblock 内不能容纳整行数据,同时 threadblock 间的通信很低效,所以不能仅在 Q x K 的 epilogue 中完成整个 softmax 的操作。把 softmax 拆分成三步计算,分别 fuse 到 Q x K 的 epilogue 中, QK x V 的 prologue 中,以及中间再添加一个轻量的 kernel 做规约。
除矩阵乘和 attention 的优化外,ByteTransformer 还对一些小的操作进行了全面的 kernel fusion,通过减少显存访问和 kernel launch 的开销,可以获得更极致的性能。
矩阵乘之后的 add-bias 和 LayerNorm 操作,通过手写 kernel 的方式做 fusion,这部分操作在 seqlen 为 256 和 1024 的情况下分别占 10% 和 6% 的延迟,fused kernel 可以优化 61% 的性能,对单层 BERT Transformer 的性能提升 3.2%(平均 seqlen 128 - 1024 的情况)。
通过 CUTLASS fuse epilogue 的方式,把矩阵乘后的 add-bias 操作和 GELU activation 操作 fuse 到矩阵乘 kernel 中。add-bias 和 GELU 在 seqlen 为 256 和 1024 的情况下占比耗时分别为 7% 和 5%。把 add-bias 和 GELU 融合 GEMM 可以完美隐藏这部分的访存延迟,进一步使单层 transformer 性能提升 3.8%。
部分引用自:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。