赞
踩
近两年大模型火出天际;同时,也诞生了大量针对大模型的优化技术。本系列将针对一些常见大模型优化技术进行讲解。
而本文将针对仅编码器Transformer架构(Decoder-Only Transformer)的模型必备显存优化技术 KV Cache 进行讲解。
KV Cache 是大模型推理性能优化的一个常用技术,该技术可以在不影响任何计算精度的前提下,通过空间换时间的思想,提高推理性能。
对于仅编码器Transformer架构的模型的推理,我们给一个输入文本,模型会输出一个回答(长度为 N),其实该过程中执行了 N 次推理过程。即类 GPT 的仅编码器模型一次推理只输出一个token,输出的 token 会与输入 tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。
针对一个仅编码器Transformer架构的模型,假设用户输入为“recite the first law”,模型续写得到的输出为“A robot may not ”,模型的生成过程如下:
仅编码器Transformer架构的自回归模型为带 Masked 的 Self Attention。因此,在没有KV Cache的情况下,其计算过程如下所示。
正常情况下,Attention的计算公式如下:
为了看上去方便,我们暂时忽略scale项,因此,Attention的计算公式如下所示(softmaxed 表示已经按行进行了softmax):
当QKTQK^TQKT变为矩阵时,softmax 会针对行进行计算,详细如下(softmaxed 表示已经按行进行了softmax):
其中,Att1(Q,K,V)Att_1(Q,K,V)Att1(Q,K,V)表示 Attention 的第一行, Att2(Q,K,V)Att_2(Q,K,V)Att2(Q,K,V)表示 Attention 的第二行。
对于Att1(Q,K,V)Att_1(Q,K,V)Att1(Q,K,V),由于Q1K2TQ_1K_2^TQ1K2T这个值会mask掉,你会发现,Q1Q_1Q1 在第二步参与的计算与第一步是完全一样的,并且 V1V_1V1 参与计算Attention时也仅仅依赖于 Q1Q_1Q1 ,与 Q2Q_2Q2 毫无关系。
对于Att2(Q,K,V)Att_2(Q,K,V)Att2(Q,K,V),V2V_2V2 参与计算Attention时也仅仅依赖于Q2Q_2Q2 ,与 Q1Q_1Q1 毫无关系。
其计算方式如 Step2 所示。
其计算方式如 Step2 所示。
对于Attk(Q,K,V)Att_k(Q,K,V)Attk(Q,K,V), VkV_kVk 参与计算Attention时也仅仅依赖于 QkQ_kQk。
看上面图和公式,我们可以得出以下结论:
正是因为 Self Attention 中带 Masked ,因此,在推理的时候,前面已经生成的 Token 不需要与后面的 Token 产生 Attention ,从而使得前面已经计算的 K 和 V 可以缓存起来。
一个典型的带有 KV cache 优化的生成大模型的推理过程包含了两个阶段:
预填充阶段:输入一个prompt序列,为每个transformer层生成 key cache 和 value cache(KV cache)。
解码阶段:使用并更新KV cache,一个接一个地生成token,当前生成的token词依赖于之前已经生成的token。
预填充阶段计算过程如下:
解码阶段计算过程如下:
下图展示了使用KV Cache和不使用KV Cache的对比,其中,紫色部分表示从缓存获取,灰色部分表示会被Masked。
下面使用 transformers 来比较有 KV Cache 和没有 KV Cache的情况下,GPT-2的生成速度。
import numpy as np import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("gpt2") model = AutoModelForCausalLM.from_pretrained("gpt2").to(device) for use_cache in (True, False): times = [] for _ in range(10): # measuring 10 generations start = time.time() model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000) times.append(time.time() - start) print(f"{'with' if use_cache else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")
运行结果:
可以看到使不使用 KV cache 推理性能果差异显存。
FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小。
如何计算矩阵乘法的FLOPs呢?
对于 声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。