赞
踩
关于位置编码和RoPE
考虑到只要花足够多的时间 心思 投入,没有写不清楚的,讲课更是如此,故为彻底解决这个位置编码/RoPE的问题,我把另外两篇文章中关于位置编码的内容抽取出来,并不断深入、扩展、深入,比如其中最关键的改进是两轮改进,一个12.16那天,一个12.21那天
最终成为本文
如此篇文章《Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT》所述,RNN的结构包含了序列的时序信息,而Transformer却完全把时序信息给丢掉了,比如“他欠我100万”,和“我欠他100万”,两者的意思千差万别,故为了解决时序的问题,Transformer的作者用了一个绝妙的办法:位置编码(Positional Encoding)
即将每个位置编号,从而每个编号对应一个向量,最终通过结合位置向量和词向量,作为输入embedding,就给每个词都引入了一定的位置信息,这样Attention就可以分辨出不同位置的词了,具体怎么做呢?
至于是embedding向量的位置下标对2求商并取整(可用双斜杠表示整数除法,即求商并取整),它的取值范围是,比如
位置向量的第多少维 (0 2 4等偶数维用sin函数计算) | |||
0 | |||
1 | |||
2 | |||
3 | |||
4 | |||
5 | |||
6 | |||
.... | |||
510 | |||
511 |
不要小看transformer的这个位置编码,不少做NLP多年的人也不一定对其中的细节有多深入,而网上大部分文章谈到这个位置编码时基本都是千篇一律、泛泛而谈,很少有深入,故本文还是细致探讨下
考虑到一图胜千言 一例胜万语,举个例子,当我们要编码「我 爱 你」的位置向量,假定每个token都具备512维,如果位置下标从0开始时,则根据位置编码的计算公式可得『且为让每个读者阅读本文时一目了然,我计算了每个单词对应的位置编码示例(在此之前,这些示例在其他地方基本没有)』
然后再叠加上embedding向量,可得
最终得到的可视化效果如下图所示
代码实现如下
- “”“位置编码的实现,调用父类nn.Module的构造函数”“”
- class PositionalEncoding(nn.Module):
- def __init__(self, d_model, dropout, max_len=5000):
- super(PositionalEncoding, self).__init__()
- self.dropout = nn.Dropout(p=dropout) # 初始化dropout层
-
- # 计算位置编码并将其存储在pe张量中
- pe = torch.zeros(max_len, d_model) # 创建一个max_len x d_model的全零张量
- position = torch.arange(0, max_len).unsqueeze(1) # 生成0到max_len-1的整数序列,并添加一个维度
- # 计算div_term,用于缩放不同位置的正弦和余弦函数
- div_term = torch.exp(torch.arange(0, d_model, 2) *
- -(math.log(10000.0) / d_model))
-
- # 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- pe = pe.unsqueeze(0) # 在第一个维度添加一个维度,以便进行批处理
- self.register_buffer('pe', pe) # 将位置编码张量注册为缓冲区,以便在不同设备之间传输模型时保持其状态
-
- # 定义前向传播函数
- def forward(self, x):
- # 将输入x与对应的位置编码相加
- x = x + Variable(self.pe[:, :x.size(1)],
- requires_grad=False)
- # 应用dropout层并返回结果
- return self.dropout(x)
本文发布之后,有同学留言问,上面中的第11行、12行代码
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
为什么先转换为了等价的指数+对数运算,而不是直接幂运算?是效率、精度方面有差异吗?
这里使用指数和对数运算的原因是为了确保数值稳定性和计算效率
所以,使用指数和对数运算可以在保持数值稳定性的同时提高计算效率。
既然提到了这行代码,我们干脆就再讲更细致些,上面那行代码对应的公式为
其中的中括号对应的是一个从 0 到 的等差数列(步长为 2),设为
且上述公式与这个公式是等价的
为何,原因在于,从而有
最终,再通过下面这两行代码完美实现位置编码
- # 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
先复习下复数的一些关键概念
在我们的日常生活中,经常会遇到各种平移运动,为了描述这些平移运动,数学上定义了加减乘除,然还有一类运动是旋转运动,而加减乘除无法去描述旋转运动,而有了复数之后,便不一样了,此话怎讲?
根据复数的定义:,可以看出来:,而这个展开过程就揭示了虚数 背后的本质,因为这个展开过程中的两次乘法可以看成连续的操作
so, 就代表了旋转(至此,可能你已经隐隐约约意识到,为何我们在解释旋转位置编码时,为何要扯上复数了),为形象说明,再举两个例子
当 表示任意实数, 是自然对数的底数, 是复数中的虚数单位,则根据欧拉公式有
表达的含义在于该指数函数可以表示为实部为,虚部为的一个复数
该欧拉公式相当于建立了指数函数、三角函数和复数之间的桥梁,但怎么推导出来的呢,其实很简单
- 由于有
- 所以,如果 ,则有
如何直观的理解这个欧拉公式呢?
其实,可以把看作通过单位圆的圆周运动来描述单位圆上的点,通过复平面的坐标来描述单位圆上的点,是同一个点不同的描述方式,所以有,如下图所示
根据欧拉公式,可以轻易推出:
我们把复数当作向量来看待,复数的实部是方向,虚部是方向,很容易观察出其几何意义,如下图所示
还在思考怎么得来的?很简单哦,还记得向量的加减法么?
所谓旋转位置编码,其在位置编码上删除了绝对位置嵌入,而在网络的每一层增加了苏剑林等人(2021)提出的旋转位置嵌入(RoPE),其思想是采用绝对位置编码的形式 实现相对位置编码,且RoPE主要借助了复数的思想
具体来说,当咱们给self-attention中的向量都加入了位置信息后,便可以表示为
其中
接着论文中提出为了能利用上 token 之间的相对位置信息,假定 query 向量 和 key 向量 之间的内积操作可以被一个函数 表示,该函数 的输入是词嵌入向量 、 ,和它们之间的相对位置 :
这里面其实有很大的一个关键,但大部分资料甚至RoPE原始论文都不会给你特别强调出来,即为何要构造这么一个等式呢?
- 原因在于左边算是q和k向量的内积,而这恰好是transformer计算自注意力机制的核心一步,右边等式则意味着m与n的相对位置
如此一来,该等式便把“q和k的内积”与“它们的相对位置”给串起来了- 也如阿荀所说,左边是含有各自绝对位置信息的q向量和k向量,而这个等式就是RoPE追求的目标,物理含义就是通过显式传入绝对位置信息实现与传入相对位置信息对等的情况
假定现在词嵌入向量的维度是两维 ,然后RoPE利用2维度平面上的向量的几何性质,再结合复数的性质,神奇般的找到了满足上述等式的 和 ,其形式如下:
这里面的 Re 表示复数的实部
然上述分别关于、、的三个式子,咋一步一步推导来的?为做细致说明,特参考此文一步一步解释下
首先看第一个式子,对于,这个式子的右边项有两部分,一部分是、一部分是
- 对于前者,可知其中的是个二维矩阵,是个二维向量,自然相乘的结果也必然是一个二维向量,用表示
- 对于后者,根据欧拉公式,可得
- 基于上面第1点结论,可知
然后将表示成复数形式,可得
从而有
基于上面第2点结论,可知即是两个复数相乘- 考虑到以下两个关于复数的背景知识
可得
将这个结果表达成实数向量形式,即是
至此,你也就不难发现,这不就是query向量乘以了一个旋转矩阵么
fq(xm,m)=(Wqxm)eimθ=qmeimθ=[q(1)mcos(mθ)−q(2)msin(mθ),q(2)mcos(mθ)+q(1)msin(mθ)]=(cos(mθ)−sin(mθ)sin(mθ)cos(mθ))(q(1)mq(2)m)至于第二个式子,根据上述过程同理,可得key向量
最后第三个式子,函数g,则可得
- g(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ]
其中,Re[x]表示一个复数的实数部分,而(Wkxn)∗则表示复数Wkxn的共轭
- 考虑到
再结合上面第一个式子中的推导,可得
Wqxm=qm=q(1)m+iq(2)mWkxn=kn=k(1)n+ik(2)n(Wkxn)∗=k∗n=k(1)n−ik(2)nei(m−n)θ=cos((m−n)θ)+isin((m−n)θ)
继续结合上面第一个式子中的推导(比如,及),继续可知,我们现在要证明的是存在
g(xm,xn,m−n)=Re[(Wqxm)(Wkxn)∗ei(m−n)θ]=Re[(q(1)m+iq(2)m)(k(1)n−ik(2)n)(cos((m−n)θ)+isin((m−n)θ))]=Re[((q(1)mk(1)n+q(2)mk(2)n)+i(q(2)mk(1)n−q(1)mk(2)n))(cos((m−n)θ)+isin((m−n)θ))]=(q(1)mk(1)n+q(2)mk(2)n)cos((m−n)θ)−(q(2)mk(1)n−q(1)mk(2)n)sin((m−n)θ)- 总之,接下来我们就要证明上述函数 g 的计算公式是成立的
首先,回顾一下attention操作,位置m的query和位置n的key会做一个内积操作
即由
fq(xm,m)=[q(1)mcos(mθ)−q(2)msin(mθ),q(2)mcos(mθ)+q(1)msin(mθ)]fk(xn,n)=[k(1)ncos(nθ)−k(2)nsin(nθ),k(2)ncos(nθ)+k(1)nsin(nθ)]
可得
<fq(xm,m),fk(xn,n)>=(q(1)mcos(mθ)−q(2)msin(mθ))(k(1)ncos(nθ)−k(2)nsin(nθ))+(q(2)mcos(mθ)+q(1)msin(mθ))(k(2)ncos(nθ)+k(1)nsin(nθ))=q(1)mcos(mθ)k(1)ncos(nθ)−q(1)mcos(mθ)k(2)nsin(nθ)−q(2)msin(mθ)k(1)ncos(nθ)+q(2)msin(mθ)k(2)nsin(nθ)+q(2)mcos(mθ)k(2)ncos(nθ)+q(2)mcos(mθ)k(1)nsin(nθ)+q(1)msin(mθ)k(2)ncos(nθ)+q(1)msin(mθ)k(1)nsin(nθ)
「相当于[A,B]与[C,D]做内积,则相当于A B横着,C D竖着,最终结果为AC BD,最后再把括号里的项全部对应相乘、展开」- 首先,把上面第二点的式子整理一下,总计8项,为了把qk相关的项提取出来,第1项 8项合并处理、第2项 7项合并处理、第3项 6项合并处理、第4项 5项合并处理
其次,考虑到
sin(a+b)=sinacosb+cosasinbsin(a−b)=sinacosb−cosasinbcos(a+b)=cosacosb−sinasinbcos(a−b)=cosacosb+sinasinb
最后,再把相关项的特点,两次调整下顺序即可
依据以上三点,从而有
<fq(xm,m),fk(xn,n)>=q(1)mk(1)n(cos(mθ)cos(nθ)+sin(mθ)sin(nθ))+q(1)mk(2)n(−cos(mθ)sin(nθ)+sin(mθ)cos(nθ))+q(2)mk(1)n(−sin(mθ)cos(nθ)+cos(mθ)sin(nθ))+q(2)mk(2)n(sin(mθ)sin(nθ)+cos(mθ)cos(nθ))=q(1)mk(1)ncos((m−n)θ)+q(1)mk(2)nsin((m−n)θ)−q(2)mk(1)nsin((m−n)θ)+q(2)mk(2)ncos((m−n)θ)=(q(1)mk(1)n+q(2)mk(2)n)cos((m−n)θ)+(q(1)mk(2)n−q(2)mk(1)n)sin((m−n)θ)=(q(1)mk(1)n+q(2)mk(2)n)cos((m−n)θ)−(q(2)mk(1)n−q(1)mk(2)n)sin((m−n)θ)=g(xm,xn,m−n)
完美! 如此,也就证明了,位置 m 的 query 和位置 n 的 key 的内积就是函数 g最后,把上面的式子一、式子二的最终结果都分别用矩阵向量乘的形式来表达就是:
<fq(xm,m),fk(xn,n)>=((cos(mθ)−sin(mθ)sin(mθ)cos(mθ))(q(1)mq(2)m))T((cos(nθ)−sin(nθ)sin(nθ)cos(nθ))(k(1)nk(2)n))=(q(1)mq(2)m)(cos(mθ)sin(mθ)−sin(mθ)cos(mθ))(cos(nθ)−sin(nθ)sin(nθ)cos(nθ))(k(1)nk(2)n)
接下来,我们要计算两个旋转矩阵的乘积,即中间部分的这个式子
(cos(mθ)sin(mθ)−sin(mθ)cos(mθ))(cos(nθ)−sin(nθ)sin(nθ)cos(nθ))
展开之后,可得
(cos(mθ)cos(nθ)+sin(mθ)sin(nθ)−cos(mθ)sin(nθ)+sin(mθ)cos(nθ)−sin(mθ)cos(nθ)+cos(mθ)sin(nθ)sin(mθ)sin(nθ)+cos(mθ)cos(nθ))
从而有
<fq(xm,m),fk(xn,n)>=(q(1)mq(2)m)(cos((m−n)θ)−sin((m−n)θ)sin((m−n)θ)cos((m−n)θ))(k(1)nk(2)n)
上面都还只是针对词嵌入维度为2的情况,那对于d>=2的通用情况呢,将2维推广到任意维度,可以表示如下:
f{q,k}(xm,m)=RdΘ,mW{q,k}xm
内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接,即将词嵌入向量元素按照两两一组分组
RdΘ,m=(cosmθ0−sinmθ000⋯00sinmθ0cosmθ000⋯0000cosmθ1−sinmθ1⋯0000sinmθ1cosmθ1⋯00⋮⋮⋮⋮⋱⋮⋮0000⋯cosmθd/2−1−sinmθd/2−10000⋯sinmθd/2−1cosmθd/2−1)⏟Wm
每组应用同样的旋转操作且每组的旋转角度计算方式如下:
Θ={θi=10000−2(i−1)/d,i∈[1,2,…,d/2]}
所以简单来说 RoPE 的 self-attention 操作的流程是
与上面第一种形式的推导类似,为了引入复数,首先假设了在加入位置信息之前,原有的编码向量是二维行向量和,其中和是绝对位置,现在需要构造一个变换,将和引入到和中,即寻找变换:
~qm=f(q,m),~kn=f(k,n)
也就是说,我们分别为、设计操作、,使得经过该操作后,、就带有了位置、的绝对位置信息
考虑到Attention的核心计算是内积:
Attention(Q,K,V)=softmax(QKT√dk)V
故我们希望的内积的结果带有相对位置信息,即寻求的这个变换,应该具有特性:
「怎么理解?很简单,当m和n表示了绝对位置之后,m与n在句子中的距离即位置差m-n,就可以表示为相对位置了,且对于复数,内积通常定义为一个复数与另一个复数的共轭的乘积」
⟨qmeimθ,kneinθ⟩=Re[(qmeimθ)(kneinθ)∗]=Re[qmk∗nei(m−n)θ]
这样一来,内积的结果就只依赖于,也就是相对位置了于是,对于任意的位置为的二维向量[x,y],把它看做复数,乘以eimθ,而根据欧拉公式,有:
eimθ=cosmθ+isinmθ
从而上述的相乘变换也就变成了(过程中注意:i2=−1):
把上述式子写成矩阵形式:
而这个变换的几何意义,就是在二维坐标系下,对向量(q0,q1)进行了旋转,因而这种位置编码方法,被称为旋转位置编码
根据刚才的结论,结合内积的线性叠加性,可以将结论推广到高维的情形。可以理解为,每两个维度一组,进行了上述的“旋转”操作,然后再拼接在一起:
由于矩阵的稀疏性,会造成计算上的浪费,所以在计算时采用逐位相乘再相加的方式进行:
其中⊗为矩阵逐位相乘操作
原理理解了,接下来可以代码实现旋转位置编码,考虑到LLaMA本身的实现不是特别好理解,所以我们先通过一份非LLaMA实现的版本,最后再看下LLaMA实现的版本
对于,非LLaMA版的实现,其核心就是实现下面这三个函数 (再次强调,本份关于RoPE的非LLaMA版的实现 与上面和之后的代码并非一体的,仅为方便理解RoPE的实现)
sinusoidal_position_embedding:这个函数用来生成正弦形状的位置编码。这种编码用来在序列中的令牌中添加关于相对或绝对位置的信息
- def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
- # (max_len, 1)
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
-
- # (output_dim//2)
- # 即公式里的i, i的范围是 [0,d/2]
- ids = torch.arange(0, output_dim // 2, dtype=torch.float)
- theta = torch.pow(10000, -2 * ids / output_dim)
-
- # (max_len, output_dim//2)
- # 即公式里的:pos / (10000^(2i/d))
- embeddings = position * theta
-
- # (max_len, output_dim//2, 2)
- embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
-
- # (bs, head, max_len, output_dim//2, 2)
- # 在bs维度重复,其他维度都是1不重复
- embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))
-
- # (bs, head, max_len, output_dim)
- # reshape后就是:偶数sin, 奇数cos了
- embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
- embeddings = embeddings.to(device)
- return embeddings
一般的文章可能解释道这个程度基本就over了,但为了让初学者一目了然计,我还是再通过一个完整的示例,来一步步说明上述各个步骤都是怎么逐一结算的,整个过程和之前此文里介绍过的transformer的位置编码本质上是一回事..
为方便和transformer的位置编码做对比,故这里也假定output_dim = 512
,
,
,
,
,
,
...
,
,
ids = [0,0, 1,1, 2,2, ..., 254,254, 255,255]
- [
- [
- [
- [sin(\frac{0}{10000^{\frac{0}{512}}}), cos(\frac{0}{10000^{\frac{0}{512}}}), sin(\frac{0}{10000^{\frac{2}{512}}}), cos(\frac{0}{10000^{\frac{2}{512}}}), ..., cos(\frac{0}{10000^{\frac{510}{512}}})],
- [sin(\frac{1}{10000^{\frac{0}{512}}}), cos(\frac{1}{10000^{\frac{0}{512}}}), sin(\frac{1}{10000^{\frac{2}{512}}}), cos(\frac{1}{10000^{\frac{2}{512}}}), ..., cos(\frac{1}{10000^{\frac{510}{512}}})],
- [sin(\frac{2}{10000^{\frac{0}{512}}}), cos(\frac{2}{10000^{\frac{0}{512}}}), sin(\frac{2}{10000^{\frac{2}{512}}}), cos(\frac{2}{10000^{\frac{2}{512}}}), ..., cos(\frac{2}{10000^{\frac{510}{512}}})]
- ]
- ]
- ]
RoPE:这个函数将相对位置编码(RoPE)应用到注意力机制中的查询和键上。这样,模型就可以根据相对位置关注不同的位置
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
-
-
- def RoPE(q, k):
- # q,k: (bs, head, max_len, output_dim)
- batch_size = q.shape[0]
- nums_head = q.shape[1]
- max_len = q.shape[2]
- output_dim = q.shape[-1]
-
- # (bs, head, max_len, output_dim)
- pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)
-
-
- # cos_pos,sin_pos: (bs, head, max_len, output_dim)
- # 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
- cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 将奇数列信息抽取出来也就是cos 拿出来并复制
- sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 将偶数列信息抽取出来也就是sin 拿出来并复制
-
- # q,k: (bs, head, max_len, output_dim)
- q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
- q2 = q2.reshape(q.shape) # reshape后就是正负交替了
-
- # 更新qw, *对应位置相乘
- q = q * cos_pos + q2 * sin_pos
-
- k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
- k2 = k2.reshape(k.shape)
- # 更新kw, *对应位置相乘
- k = k * cos_pos + k2 * sin_pos
-
- return q, k
老规矩,为一目了然起见,还是一步一步通过一个示例来加深理解
- # 注意,这只是一个简化的例子,真实的位置嵌入的值会有所不同。
- pos_emb = torch.tensor([[[[0.0000, 0.8415, 0.9093, 0.1411, 1.0000, 0.5403, -0.4161, -0.9900],
- [0.8415, 0.5403, 0.1411, -0.7568, 0.5403, -0.8415, -0.9900, -0.6536],
- [0.9093, -0.4161, -0.8415, -0.9589, -0.4161, -0.9093, -0.6536, 0.2836]]]])
- sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 提取出所有sin编码,并在最后一个维度上复制
- cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 提取出所有cos编码,并在最后一个维度上复制
- q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).flatten(start_dim=-2)
- # q2: tensor([[[[-0.2, 0.1, -0.4, 0.3, -0.6, 0.5, -0.8, 0.7],
- # [-1.0, 0.9, -1.2, 1.1, -1.4, 1.3, -1.6, 1.5],
- # [-1.8, 1.7, -2.0, 1.9, -2.2, 2.1, -2.4, 2.3]]]])
-
- q = q * cos_pos + q2 * sin_pos
公式表示如下 - k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).flatten(start_dim=-2)
- # k2: tensor([[[[-0.15, 0.05, -0.35, 0.25, -0.55, 0.45, -0.75, 0.65
attention:这是注意力机制的主要功能
- def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
- # q.shape: (bs, head, seq_len, dk)
- # k.shape: (bs, head, seq_len, dk)
- # v.shape: (bs, head, seq_len, dk)
-
- if use_RoPE:
- # 使用RoPE进行位置编码
- q, k = RoPE(q, k)
-
- d_k = k.size()[-1]
-
- # 计算注意力权重
- # (bs, head, seq_len, seq_len)
- att_logits = torch.matmul(q, k.transpose(-2, -1))
- att_logits /= math.sqrt(d_k)
-
- if mask is not None:
- # 对权重进行mask,将为0的部分设为负无穷大
- att_scores = att_logits.masked_fill(mask == 0, -1e-9)
-
- # 对权重进行softmax归一化
- # (bs, head, seq_len, seq_len)
- att_scores = F.softmax(att_logits, dim=-1)
-
- if dropout is not None:
- # 对权重进行dropout
- att_scores = dropout(att_scores)
-
- # 注意力权重与值的加权求和
- # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
- return torch.matmul(att_scores, v), att_scores
-
-
- if __name__ == '__main__':
- # (bs, head, seq_len, dk)
- q = torch.randn((8, 12, 10, 32))
- k = torch.randn((8, 12, 10, 32))
- v = torch.randn((8, 12, 10, 32))
-
- # 进行注意力计算
- res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)
-
- # 输出结果的形状
- # (bs, head, seq_len, dk), (bs, head, seq_len, seq_len)
- print(res.shape, att_scores.shape)
接下来,我们再来看下LLaMA里是怎么实现这个旋转位置编码的,具体而言,LLaMA 的model.py文件里面实现了旋转位置编码(为方便大家理解,我给相关代码 加了下注释)
首先,逐一实现这三个函数
precompute_freqs_cis
reshape_for_broadcast
apply_rotary_emb
- # 预计算频率和复数的函数
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # 计算频率
- t = torch.arange(end, device=freqs.device) # 根据结束位置生成序列
- freqs = torch.outer(t, freqs).float() # 计算外积得到新的频率
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # 计算复数
- return freqs_cis # 返回复数
- # 重塑的函数
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
- ndim = x.ndim # 获取输入张量的维度
- assert 0 <= 1 < ndim # 检查维度的合理性
- assert freqs_cis.shape == (x.shape[1], x.shape[-1]) # 检查复数的形状
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 计算新的形状
- return freqs_cis.view(*shape) # 重塑复数的形状并返回
- # 应用旋转嵌入的函数
- def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # 将xq视为复数
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # 将xk视为复数
- freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # 重塑复数的形状
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # 计算xq的输出
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) # 计算xk的输出
- return xq_out.type_as(xq), xk_out.type_as(xk) # 返回xq和xk的输出
之后,在注意力机制的前向传播函数中调用上面实现的第三个函数 apply_rotary_emb,赋上位置信息
- # 对Query和Key应用旋转嵌入
- xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
模型名称 | 隐藏层维度 | 层数 | 注意力头数 | 词表大小 | 训练数据(tokens) | 位置编码 | 最大长度 |
Baichuan-7B | 4,096 | 32 | 32 | 64,000 | 1.2 万亿 | RoPE | 4,096 |
Baichuan-13B | 5,120 | 40 | 40 | 64,000 | 1.4 万亿 | ALiBi | 4,096 |
Baichuan 2-7B | 4096 | 32 | 32 | 125,696 | 2.6万亿 | RoPE | 4096 |
Baichuan 2-13B | 5120 | 40 | 40 | 125,696 | 2.6万亿 | ALiBi | 4096 |
注意看上表的位置编码那一列,baichuan 7B无论第一代还是第二代,位置编码均用的RoPE,而baichuan 13B则无论是第一代还是第二代,均用的ALiBi
下面便详细介绍下该ALiBi
ALiBi全称是Attention with Linear Biases,通过论文《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》提出,其不像标准transformer那样,在embedding层添加位置编码,而是在softmax的结果后添加一个静态的不可学习的偏置项(说白了,就是数值固定)
具体而言,如下图所示
左边是自注意力得分,关于q和k的内积
右边是一个相对距离的矩阵,
q1 q2 q3 q4 q5
k1 k2 k3 k4 k5
所以才有
→ q1和k1之间的距离是0,所以对应位置就是0
→ q2和k1之间的距离是「相对位置偏移为“k的索引”1」 - 「q的索引2」,得到1-2 = -1,就对应到了中间矩阵的取值为-1了
以此类推,相对距离矩阵的中间对角线上都是0,然后左下角的取值都是对应的「k的索引」-「q的索引」了
// 待更
9月底,GenAI, Meta正式发布LLaMA 2 Long(这是其论文《Effective Long-Context Scaling of Foundation Models》),与LLaMA 2相比,LLaMA 2 Long的变化主要体现在以下两点
在LLaMA 2中,它的位置编码采用的是旋转编码RoPE方法,其通过旋转矩阵来实现位置编码的外推
然,Meta的研究人员通过对70亿规模的LLaMA 2进行实验,确定了LLaMA 2中的RoPE方法的一个局限性,即,阻止注意力模块聚集远处token的信息
为此,Meta想出了一个非常简单的破解办法:
减少每个维度的旋转角度(which essentially reduces the rotation angles of each dimension)
具体而言就是将超参数“基频(base frequency)b”从10000增加到500000(increasing the “base frequency b” of ROPE from 10, 000 to 500, 000)
在附录中,Meta还通过可视化为螺旋图这一非常有趣的方式,将RoPE ABF与RoPE PI的差异进行了理论分析
总之,与RoPE PI相比,RoPE ABF的优势主要体现在它能以更大的粒度分配嵌入向量(the embedded vectors),从而使模型更容易区分位置
此外,他们还观察到,嵌入向量之间的相对距离既对RoPE PI的关键参数有线性依赖性,也对RoPE ABF的关键参数也有对数依赖性。
这也就是为什么可以很容易地对基频这一超参数“下手”
这一改动立刻奏效,缩小了RoPE对远端token的衰减效应,并且在扩展LLAMA的上下文长度上优于一项类似的名为“位置插值”的方法RoPE PI(如下图所示,RoPE表示基线方法,RoPE ABF为Meta此次发明的新方法,xPos是另一种应用了该方法的旋转编码变体)
然,一个问题是,通过上面这个可视化结果,Meta观察到RoPE在长程区域出现了较大的“振荡”,这对于语言建模来说可能不是个好消息
不过,通过报告几种方法在长序列困惑度和FIRST-SENTENCE-RETRIEVAL两个任务上的表现来看,问题不大
而且,尤其在后者任务上,他们提出的RoPE ABF是唯一一个可以始终保持性能的变体
最终,LLaMA 2 Long凭借着这一改动,达成了3.2万的上下文token,并通过长下文连续预训练的共同作用,获得了开头所示的好成绩:
除了全面超越LLaMA 2、在特定任务上超越Claude 2和ChatGPT,Meta也给出了它和一些开源长下文模型的对比。结果也相当不赖,如下图所示
//待更
最后,说明下为何像开头说的是「23年12.16日这天对本文做了大修」呢,原因在于
如今博客的访问PV2000万,希望明年达到2000万UV以上,以上视为后记
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。