赞
踩
LLM 推理任务需要大量的算力,将现代 GPU 推向极限。过去两年, LLM 训练和推理优化相关的研究进展速度惊人,每六个月就会出现新的突破。
今天的分享主要以Llama3为例,为大家介绍LLM 推理领域所必备的一些基本数学与概念,这包含了Llama3的模型结构、参数量以及推理显存的计算,对于张量、矩阵等基本数学原理本文不会赘述。
我们先来看看标准的Transformer-decoder架构。
左图来自于大模型技术的起源之作《Attention is all you need》,这是一个标准的transformer结构,由左侧的encoder和右侧的decoder两部分构成。
如果我们将encoder部分去掉(左图中橙色掩盖部分),那么就成了我们目前以GPT为主流的decoder-only的大模型结构(右图)。
Transformer-decoder架构的模型主要包括embedding层、transformer-decoder、output层。
其中transformer-decoder由多个decoder层堆叠而成。
参数量的计算
#embedding层
embedding层将用户的输入映射到向量空间,在这层有一个词表矩阵W,它的维度是[v,d_model],v是词表长度vocab_size,d_model是词向量维度,一般和隐层维度hidden_size一致。
因此,输入层参数量为:v*d_model
#output层
output层需要将隐层维度转换到词表中每个词的概率,即将隐层维度d_model转换为v,因此会有一个[d_model,v]的矩阵。
有的模型会复用词表矩阵W,直接使用W的转置矩阵,比如gpt-2,此时不会增加额外的参数,而有的则不会,比如llama3,此时,输出层的参数量为d_model*v
#transformer-decoder
transformer-decoder采用了多个decoder层堆叠而成,每个decoder层结构相同,参数不同。
而对于每个decoder层,最主要的操作包括:Attention**、FFN,**当然除此之外还有LayerNorm。
(1)Attention的结构和参数量
在原始的transformer结构中采用的是MHA(Multi-Head Attention,多头注意力),而在Llama3中使用的是GQA(Grouped-Query Attention)的变种,下面进行分别介绍。
Multi-Head Attention
根据MHA的计算公式,在每一个head中,Q、K、V要先进行线性变换,然后进行Attention计算,最后把所有head拼接起来后再进行一次线性变换得到最终Attention部分的结果。
在这部分的计算中,可以看出其参数包括每一个head中Q、K、V的权重矩阵W_Q、W_K、W_V,这三个矩阵的维度均为[d_model, d],以及最终结果的变换矩阵W_O,其维度为[d_model,d_model]。
我们知道,在head"拆分"的时候,为了使拼接后的维度与之前保持一致,d=d_model/h
因此,MHA的参数量为:3*h*d_model*d+d_model*d_model = 4*d_model*d_model
Grouped-Query Attention
在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。
因此 Noam Shazeer 提出改造,将原来的h个KV对缩短为1个,所有query只使用一个共享的KV对,即MQA(Multi-Query Attention),这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。
此后,他又提出GQA,将query进行分组,每组query共享一个KV对,在降低显存的同时保证模型性能。
在Llama3-8B中,原始的MHA的heads为32,GQA的heads为8,KV对的参数量变为原来的1/4,因此,Llama3中的Attention结构的参数量为d_model*d_model+2*d_model*d/4+d_model*d_model
(2)FFN的结构和参数量
在原始的transformer结构中采用的是常见的前馈神经网络(FFN,也称MLP),而在Llama3中使用的是SwiGLU的变种,下面进行分别介绍。
原始FFN
在标准的Transformer中, FFN层是一个典型的标准多层网络结构,由两个全连接层组成,先向上升维到原来的4倍,经过一个Relu激活函数,再向下降维到原始维度。
因此,原始FNN参数量为d_model*d_ffn+d_ffn*d_model = 8*d_model*d_model
左图:原始FFN;右图:SwiGLU变体FFN
SwiGLU
SwiGLU也是由Noam Shazeer提出的一个变种,主要引入了GLU和Swish激活函数。
在 SwiGLU变体中,第一个全连接层和激活函数被替换成了GLU和swish激活函数。SwiGLU采用了两个权重矩阵分别对输入进行升维变换,再通过Swish激活函数做哈达玛乘积操作,最后再进行维度还原。
可以看出,在SwiGLU变体的FFN模块中,实际有三个权重矩阵,在不同的模型中,大家会优化d_inter的大小(intermediate_size)。
因此,带有SwiGLU变体FFN的参数量为2*d_model*d_inter+d_model*d_inter = 3d_model*d_inter
#汇总计算
以Llama3-8B为例,在Llama3-8B中一共由32层这样的decoder层堆叠而成,因此,对所有的参数进行汇总如下:
Total_parameters = 525336576+32*(16777216+4194304+4194304+16777216+58720256+58720256+58720256)+525336576
=8,029,995,008~=8B
在上面的计算中我们没有考虑LayerNorm层的参数(Llama3中采用的是改进的RMSNorm),因为LayerNorm层参数量较少。
在decoder结构中有两次LayerNorm操作,每层参数量为4096,如果加上这部分参数,总参数量为8,029,995,008+32*2*4096=8,030,257,152,这也正好对上了8B的参数量
推理显存量计算
推理的显存一般由以下几个部分构成:
#模型参数
模型中的每个参数都需要存储,因此显存占用为:
以Llama3-8B模型为例,使用fp16的数据类型存储,对应的显存量约为:
#KV Cache
我们知道,在模型推理过程中,模型一次生成一个token,然后使用之前生成的token作为输入来预测下一个token。
每次生成新的token时,模型需要重新计算新的Q、K、V,并基于它们计算Attention权重。然而,之前生成的K、V在当前解码过程中是可以重复利用的,为了加快推理速度,可以将之前计算好的K、V存储在缓存中,这就是KV Cache,它们存储在GPU显存中,从而节省计算时间。
KV Cache的推理包含两个阶段:
**prefill阶段:**在生成第一个token时,模型计算每个Transformer层的K、V,并存储到缓存中。
decode阶段:在此后生成的每一个token,由于KV Cache已经缓存了之前所有轮次的K、V,每轮推理只需从Cache中读取数据,并将新计算的Key和Value更新到Cache中
这部分的显存,大概可以这样计算:
B: batch_size
2: 代表K、V均需存储
l: decoder层数
h: 每层注意力头数
d: 单头维度
s: 句子序列长度
以Llama3-8B模型为例,对于一条长度为2048 token的序列,使用fp16的数据类型存储KV cache,对应的峰值显存量约为:
#中间激活值
中间激活值(Intermediate Activations)指模型在各层之间传递的中间结果。这些激活值反映了输入数据在每层网络中的特征表示,在模型在推理过程中,也涉及到这些中间激活值的存储。
c: 中间变量个数,取决于具体模型结构和推理策略,本文不做详细介绍,用常量c表示
#汇总计算
总体来说,在推理过程中,模型权重和 KV cache约占 GPU 显存总需求的 90%。
而且,从我们上面的公式可以看出:模型参数量和序列长度无关,但是 KV cache和序列的长度成正比。随着序列的增长而快速线性增长。当我们处理一条128k长度的请求时,按照公式计算,Llama3-8B模型的KV cache缓存将需要67.1G,已经远远超过了模型参数所需的显存量,因此我们在进行一些显卡配置评估时,除了模型本身,输入输出的序列长度也是一个重要的因素,超长上下文很香,但是硬件资源的消耗也非常大。
当然,在推理领域还有大量的研究和改进技术,比如Flash Attention、Paged Attention、量化、连续批处理等等来优化推理内存和速度,感兴趣的同学可以深入了解。
附上本文的一些学习资料
**【Llama3模型结构源码】**https://github.com/meta-llama/llama3/blob/main/llama/model.py
**【GPT2模型结构源码】**https://github.com/openai/gpt-2/blob/master/src/model.py
**【Transformer论文链接】**https://arxiv.org/pdf/1706.03762
**【GPT论文链接】**https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf
**【GQA论文链接】**https://arxiv.org/pdf/2305.13245
**【SwiGLU论文链接】**https://arxiv.org/pdf/2002.05202
**【Llama3技术报告】**https://scontent-sjc3-1.xx.fbcdn.net/v/t39.2365-6/453304228_1160109801904614_7143520450792086005_n.pdf?_nc_cat=108&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=ivumNlLcSY8Q7kNvgEZ22Pr&_nc_ht=scontent-sjc3-1.xx&oh=00_AYAHSJTq5zBtLRF5pOS5Getm2cMUSv6RvePUUKImGvHrlw&oe=66BA1007
本次关于大模型参数量和推理显存的计算就到这里了, 我们以llama3为例,介绍了其模型结构、参数量以及推理显存的详细计算。希望你喜欢这份文章,并从中学到一些新知识!如果你对大模型有更多的兴趣,欢迎继续关注我们。感谢你的阅读!
由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。
但是具体到个人,只能说是:
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
保证100%免费
】我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。
该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。
恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。
到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?
对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。
学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。
如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。
保证100%免费
】Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。