当前位置:   article > 正文

十分钟读懂旋转编码(RoPE)

rotaryembedding

9f6bf79b0f61e05459053aa2adaa3ed3.gif

©作者 | 绝密伏击

单位 | 奇虎360高级算法专家

旋转位置编码(Rotary Position Embedding,RoPE)是论文 Roformer: Enhanced Transformer With Rotray Position Embedding 提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA、GLM 模型也是采用该位置编码方式。

和相对位置编码相比,RoPE 具有更好的外推性,目前是大模型相对位置编码中应用最广的方式之一。

备注:什么是大模型外推性?

外推性是指大模型在训练时和预测时的输入长度不一致,导致模型的泛化能力下降的问题。例如,如果一个模型在训练时只使用了 512 个 token 的文本,那么在预测时如果输入超过 512 个 token,模型可能无法正确处理。这就限制了大模型在处理长文本或多轮对话等任务时的效果。

1520187559e3d9e2fa63c84f6fcb89c0.png

旋转编码RoPE

1.1 基本概念

在介绍 RoPE 之前,先给出一些符号定义,以及基本背景。

首先定义一个长度为 的输入序列为:

33d78e25d2c90cc72d46780bbbf8861d.png

其中 表示输入序列中第 个 token,而输入序列 对应的 embedding 表示为:

888a662b1acda408ca9241531c14ce59.png

其中 表示第 个 token 对应的 维词嵌入向量。

接着在做 self-attention 之前,会用词嵌入向量计算 向量同时加入位置信息,函数公式表达如下:

c3d5d73a22117ef6fddacb7b94badc1a.png

其中 表示第 个 token 对应的词向量 集成位置信息 之后的 query 向量。而 和 则表示第 个 token 对应的词向量 集成位置信息 之后的 key 和 value 向量。

而基于 transformer 的位置编码方法都是着重于构造一个合适的 函数形式。

而计算第 个词嵌入向量 对应的 self-attention 输出结果,就是 和其他 都计算一个 attention score ,然后再将 attention score 乘以对应的 再求和得到输出向量 :

29e91d35becd8864f9ff95a1024620b7.png

1.2 绝对位置编码

对于位置编码,常规的做法是在计算 query,key 和 value 向量之前,会计算一个位置编码向量 加到词嵌入 上,位置编码向量 同样也是 维向量,然后再乘以对应的变换矩阵 :

31f47446545e2c81a0366b59f8998914.png

而经典的位置编码向量 的计算方式是使用 Sinusoidal 函数:

41bdbb492526c20fd2c6972334ba17f9.png

其中 表示位置 维度向量 中的第 位置分量也就是偶数索引位置的计算公式,而 就对应第 位置分量也就是奇数索引位置的计算公式。

1.3 2维旋转位置编码

论文中提出为了能利用上 token 之间的相对位置信息,假定 query 向量 和 key 向量 之间的内积操作可以被一个函数 表示,该函数 的输入是词嵌入向量 , 和它们之间的相对位置 :

8f8f717834af19c6d7c23a05511697da.png

接下来的目标就是找到一个等价的位置编码方式,从而使得上述关系成立。

假定现在词嵌入向量的维度是两维 ,这样就可以利用上 2 维度平面上的向量的几何性质,然后论文中提出了一个满足上述关系的 和 的形式如下:

08ef03c095e96b520b7cf8bc5177a37e.png

这里面 Re 表示复数的实部。

进一步地, 可以表示成下面的式子:

ff760586d00a28bc10d0f9f87f028df9.png

看到这里会发现,这不就是 query 向量乘以了一个旋转矩阵吗?这就是为什么叫做旋转位置编码的原因。

同理, 可以表示成下面的式子:

abd722e75099ac9dd970106cc3bb8788.png

最终 可以表示如下:

3be323ef4d55caab404736adcdd91e3c.png

关于上面公式(8)~(11)的具体推导,可以参见文章最后的附录,或者参考文章:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)。

1.4 扩展到多维

将2维推广到任意维度,可以表示如下:

398c7afe2a51138ef3b6a886c556adcc.png

内积满足线性叠加性,因此任意偶数维的 RoPE,我们都可以表示为二维情形的拼接,即

8a04aa5b9ed260550efee202b216b4eb.png

将 RoPE 应用到前面公式(4)的 Self-Attention 计算,可以得到包含相对位置信息的 Self-Attetion:

dd620689367c0528215a9ca6a71c90ea.png

其中,。

值得指出的是,由于 是一个正交矩阵,它不会改变向量的模长,因此通常来说它不会改变原模型的稳定性。

1.5 RoPE 的高效计算

由于 的稀疏性,所以直接用矩阵乘法来实现会很浪费算力,推荐通过下述方式来实现 RoPE:

e13056935124cda9407af948ec541002.png

其中 是逐位对应相乘,即计算框架中的 运算。从这个实现也可以看到,RoPE 可以视为是乘性位置编码的变体。

总结来说,RoPE 的 self-attention 操作的流程是:对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照两两一组应用旋转变换,最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果。

论文中有个很直观的图片展示了旋转变换的过程:

77b460d0e8f7b89408c5d0f56955a380.png

1.6 远程衰减

可以看到,RoPE 形式上和前面公式(6)Sinusoidal 位置编码有点相似,只不过 Sinusoidal 位置编码是加性的,而 RoPE 可以视为乘性的。在 的选择上,RoPE 同样沿用了 Sinusoidal 位置编码的方案,即 ,它可以带来一定的远程衰减性。

具体证明如下:将 两两分组后,它们加上 RoPE 后的内积可以用复数乘法表示为:

e57495087cae1a35ff3d3bd0ec7c3398.png

f452476834d89a142d73623b39a357e6.png

并约定 ,那么由 Abel 变换(分部求和法)可以得到:

a1f3c2a0f873c272f90d2edf3dcffa68.png

所以

f67dcae11c64d9b755ad87d32c01a368.png

因此我们可以考察 随着相对距离的变化情况来作为衰减性的体现:

eb1f3c84a487e2e25bd20694547431c4.png

从图中我们可以看到随着相对距离的变大,内积结果有衰减趋势的出现。因此,选择 ,确实能带来一定的远程衰减性。论文中还试过以 为初始化,将 视为可训练参数,然后训练一段时间后发现 并没有显著更新,因此干脆就直接固定 了。

a68a1f87918c080d8cba840ed9cbb592.png

RoPE实验

我们看一下 RoPE 在预训练阶段的实验效果:

e86d1589f33b1989a4b25c689913346d.png

从上面可以看出,增大序列长度,预训练的准确率反而有所提升,这体现了 RoPE 具有良好的外推能力。

下面是在下游任务上的实验结果:

1c4eb102b49fb8f95ffbbde37f275c36.png

其中 RoFormer 是一个绝对位置编码替换为 RoPE 的 WoBERT 模型,后面的参数(512)是微调时截断的maxlen,可以看到 RoPE 确实能较好地处理长文本语义。

c8c3153be5a49e1a58b0a61b94f11b71.png


RoPE代码实现

Meta 的 LLAMA 和 清华的 ChatGLM 都使用了 RoPE 编码,下面看一下具体实现。

3.1 在LLAMA中的实现

  1. # 生成旋转矩阵
  2. def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
  3.     # 计算词向量元素两两分组之后,每组元素对应的旋转角度\theta_i
  4.     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  5.     # 生成 token 序列索引 t = [01,..., seq_len-1]
  6.     t = torch.arange(seq_len, device=freqs.device)
  7.     # freqs.shape = [seq_len, dim // 2] 
  8.     freqs = torch.outer(t, freqs).float()  # 计算m * \theta
  9.     # 计算结果是个复数向量
  10.     # 假设 freqs = [x, y]
  11.     # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
  12.     freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 
  13.     return freqs_cis
  14. # 旋转位置编码计算
  15. def apply_rotary_emb(
  16.     xq: torch.Tensor,
  17.     xk: torch.Tensor,
  18.     freqs_cis: torch.Tensor,
  19. ) -> Tuple[torch.Tensor, torch.Tensor]:
  20.     # xq.shape = [batch_size, seq_len, dim]
  21.     # xq_.shape = [batch_size, seq_len, dim // 2, 2]
  22.     xq_ = xq.float().reshape(*xq.shape[:-1], -12)
  23.     xk_ = xk.float().reshape(*xk.shape[:-1], -12)
  24.     # 转为复数域
  25.     xq_ = torch.view_as_complex(xq_)
  26.     xk_ = torch.view_as_complex(xk_)
  27.     # 应用旋转操作,然后将结果转回实数域
  28.     # xq_out.shape = [batch_size, seq_len, dim]
  29.     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
  30.     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
  31.     return xq_out.type_as(xq), xk_out.type_as(xk)
  32. class Attention(nn.Module):
  33.     def __init__(self, args: ModelArgs):
  34.         super().__init__()
  35.         self.wq = Linear(...)
  36.         self.wk = Linear(...)
  37.         self.wv = Linear(...)
  38.         self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)
  39.     def forward(self, x: torch.Tensor):
  40.         bsz, seqlen, _ = x.shape
  41.         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
  42.         xq = xq.view(batch_size, seq_len, dim)
  43.         xk = xk.view(batch_size, seq_len, dim)
  44.         xv = xv.view(batch_size, seq_len, dim)
  45.         # attention 操作之前,应用旋转位置编码
  46.         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
  47.         # scores.shape = (bs, seqlen, seqlen)
  48.         scores = torch.matmul(xq, xk.transpose(12)) / math.sqrt(dim)
  49.         scores = F.softmax(scores.float(), dim=-1)
  50.         output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
  51.   # ......

这里举一个例子,假设 batch_size=10, seq_len=3, d=8,则调用函数 precompute_freqs_cis(d, seq_len) 后,生成结果为:

  1. In [239]: freqs_cis
  2. Out[239]: 
  3. tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
  4.         [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
  5.         [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j]])

以结果中的第二行为例(对应的 m = 1),也就是:

243785c3dbe9e9d95d6f9ec60eb3935c.png

最终按照公式(12)可以得到编码之后的 。

注意:在代码中是直接用 freqs_cis[0] * xq_[0] 的结果表示第一个 token 对应的旋转编码(和公式 12 计算方式有所区别)。其中将原始的 query 向量 转换为了复数形式。

  1. In [351]: q_ = q.float().reshape(*q.shape[:-1], -12)
  2. In [352]: q_[0]
  3. Out[352]: 
  4. tensor([[[ 1.0247,  0.4782],
  5.          [ 1.5593,  0.2119],
  6.          [ 0.4175,  0.5309],
  7.          [ 0.4858,  0.1850]],
  8.         [[-1.7456,  0.6849],
  9.          [ 0.3844,  1.1492],
  10.          [ 0.1700,  0.2106],
  11.          [ 0.5433,  0.2261]],
  12.         [[-1.1206,  0.6969],
  13.          [ 0.8371-0.7765],
  14.          [-0.3076,  0.1704],
  15.          [-0.5999-1.7029]]])
  16. In [353]: xq = torch.view_as_complex(q_)
  17. In [354]: xq[0]
  18. Out[354]: 
  19. tensor([[ 1.0247+0.4782j,  1.5593+0.2119j,  0.4175+0.5309j,  0.4858+0.1850j],
  20.         [-1.7456+0.6849j,  0.3844+1.1492j,  0.1700+0.2106j,  0.5433+0.2261j],
  21.         [-1.1206+0.6969j,  0.8371-0.7765j, -0.3076+0.1704j, -0.5999-1.7029j]])

这里为什么可以这样计算?

主要是利用了复数的乘法性质。

我们首先来复习一下复数乘法的性质:

efc9982013970d210b374e68ee99b3b2.png

因此要计算:

3b041148ba45af2affa6a36d5fe57431.png

可以转化为计算:

79282e205120cc2baf203d03bacf9e09.png

所以可以将公式(12)转化为两个复数的乘法运算。

3.2 在ChatGLM中的实现

和 LLAMA 的实现方式相差不大。代码如下:

  1. class RotaryEmbedding(torch.nn.Module):
  2.     def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
  3.         super().__init__()
  4.          # 计算 \theta_i
  5.         inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
  6.         inv_freq = inv_freq.half()
  7.         self.learnable = learnable
  8.         if learnable:
  9.             self.inv_freq = torch.nn.Parameter(inv_freq)
  10.             self.max_seq_len_cached = None
  11.         else:
  12.             self.register_buffer('inv_freq', inv_freq)
  13.             self.max_seq_len_cached = None
  14.             self.cos_cached = None
  15.             self.sin_cached = None
  16.         self.precision = precision
  17.     def forward(self, x, seq_dim=1, seq_len=None):
  18.         if seq_len is None:
  19.             seq_len = x.shape[seq_dim]
  20.         if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
  21.             self.max_seq_len_cached = None if self.learnable else seq_len
  22.             # 生成 token 序列索引 t = [01,..., seq_len-1]
  23.             t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
  24.             # 对应m * \theta
  25.             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
  26.             # 将 m * \theta 拼接两次,对应复数的实部和虚部
  27.             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  28.             if self.precision == torch.bfloat16:
  29.                 emb = emb.float()
  30.             # [sx, 1 (b * np), hn]
  31.             cos_cached = emb.cos()[:, None, :]  # 计算得到cos(m*\theta)
  32.             sin_cached = emb.sin()[:, None, :]  # 计算得到cos(m*\theta)
  33.             if self.precision == torch.bfloat16:
  34.                 cos_cached = cos_cached.bfloat16()
  35.                 sin_cached = sin_cached.bfloat16()
  36.             if self.learnable:
  37.                 return cos_cached, sin_cached
  38.             self.cos_cached, self.sin_cached = cos_cached, sin_cached
  39.         return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
  40.     def _apply(self, fn):
  41.         if self.cos_cached is not None:
  42.             self.cos_cached = fn(self.cos_cached)
  43.         if self.sin_cached is not None:
  44.             self.sin_cached = fn(self.sin_cached)
  45.         return super()._apply(fn)
  46. def rotate_half(x):
  47.     x1, x2 = x[..., :x.shape[-1// 2], x[..., x.shape[-1] // 2:]
  48.     return torch.cat((-x2, x1), dim=x1.ndim - 1)

7d357597ed4de3f10508b45761e056de.png


RoPE的外推性

我们都知道 RoPE 具有很好的外推性,前面的实验结果也证明了这一点。这里解释下具体原因。

RoPE 可以通过旋转矩阵来实现位置编码的外推,即可以通过旋转矩阵来生成超过预期训练长度的位置编码。这样可以提高模型的泛化能力和鲁棒性。

我们回顾一下 RoPE 的工作原理:假设我们有一个 维的绝对位置编码 ,其中 是位置索引。我们可以将 看成一个 维空间中的一个点。我们可以定义一个  维空间中的一个旋转矩阵 ,它可以将任意一个点沿着某个轴旋转一定的角度。我们可以用 来变换 ,得到一个新的点 。我们可以发现, 和 的距离是相等的,即 。这意味着 和 的相对关系没有改变。但是, 和 的距离可能发生改变,即 。这意味着 和 的相对关系有所改变。因此,我们可以用 来调整不同位置之间的相对关系。

如果我们想要生成超过预训练长度的位置编码,我们只需要用 来重复变换最后一个预训练位置编码 ,得到新的位置编码

20ed4d3a9a82d857b1593c157a148300.png

依此类推。这样就可以得到任意长度的位置编码序列 ,其中 可以大于 。由于 是一个正交矩阵,它保证了 和 的距离不会无限增大或缩小,而是在一个有限范围内波动。这样就可以避免数值溢出或下溢的问题。同时,由于 是一个可逆矩阵,它保证了 和 的距离可以通过 的逆矩阵 还原到 和 的距离,即

d84f7c4196b04e108608076010f2ce22.png

这样就可以保证位置编码的可逆性和可解释性。

总结而言:

旋转编码 RoPE 可以有效地保持位置信息的相对关系,即相邻位置的编码之间有一定的相似性,而远离位置的编码之间有一定的差异性。这样可以增强模型对位置信息的感知和利用。这一点是其他绝对位置编码方式(如正弦位置编码、学习的位置编码等)所不具备的,因为它们只能表示绝对位置,而不能表示相对位置。

旋转编码 RoPE 可以通过旋转矩阵来实现位置编码的外推,即可以通过旋转矩阵来生成超过预训练长度的位置编码。这样可以提高模型的泛化能力和鲁棒性。这一点是其他固定位置编码方式(如正弦位置编码、固定相对位置编码等)所不具备的,因为它们只能表示预训练长度内的位置,而不能表示超过预训练长度的位置。

旋转编码 RoPE 可以与线性注意力机制兼容,即不需要额外的计算或参数来实现相对位置编码。这样可以降低模型的计算复杂度和内存消耗。这一点是其他混合位置编码方式(如 Transformer-XL、XLNet 等)所不具备的,因为它们需要额外的计算或参数来实现相对位置编码。

20e8121431bc8f217a90198895bc7618.png

总结

最近一直听到旋转编码这个词,但是一直没有仔细看具体原理。今天花时间仔细看了一遍,确实理论写的比较完备,而且实验效果也不错。目前很多的大模型,都选择了使用了这种编码方式(LLAMA、GLM 等)。

dc116286e607a21e5ce0b4dbba7615e8.png


附录

这里补充一下前面公式 1.3.2 节中,公式(8)~(11)是怎么推导出来的。

回到之前的公式(8),编码之后的 以及内积 的形式如下:

da37fb497be9bebe603504432560efa1.png

上面的公式为什么满足:

9ee21acccae7d9ebb7dae7ebbe1afa76.png

首先我们得先了解一下基本的复数相关知识。

首先看到上述 和 公式中有个指数函数: 

这个其实是欧拉公式,其中 表示任意实数, 是自然对数的底数, 是复数中的虚数单位,则根据欧拉公式有:

0d7e4309525e1228d593e07c3fd3dd6c.png

则是上述指数函数可以表示为实部为 ,虚部为 的一个复数,欧拉公式建立了指数函数、三角函数和复数之间的桥梁。

则上述 和 公式的

49bd62ab5431a1a9e78545770aed62a4.png

然后我们看回公式:

7b43a769f4344bf39a911071e23f1595.png

其中 是个二维矩阵, 是个二维向量,相乘的结果也是一个二维向量,这里用 表示:

b840d7372c3aacdd4ee797ee78dc3c27.png

然后首先将 表示成复数形式:

310ccd9dab945a89143bc9835f49c13f.png

接着

4f4fd1931d47627923dc6ada2458a663.png

其实就是两个复数相乘:

2987fa27e260188e1cf36cea6d13944b.png

然后就有:

f302842d6e8a8fbc11d840a6d73c65ec.png

将结果重新表达成实数向量形式就是:

1372fc5be480d46757ced326f8daaa3f.png

这里不难发现就是 query 向量乘以了一个旋转矩阵。

98c835752f984559c03b43950c193a26.png

这就是为什么叫做旋转式位置编码的原因。

同理可得 key 向量 :

78572946ac8b53cc08520ed366b31813.png

最后还有个函数 :

e30e7f5ad2ecd17a886e45ab254ba8c7.png

其中 表示一个复数 的实部部分,而 则表示复数 的共轭。

复习一下共轭复数的定义:

8d9e3c3960449fc28068f67ac7afe437.png

所以可得:

4db9319c24bc582db9e32a72d9ace90f.png

继续可得:

bebd3361c701a622e0b7fbbaef3ae53f.png

接下来我们就要证明函数 的计算公式是成立的。

首先回顾一下 attention 操作,位置 的 query 和位置 的 key 会做一个内积操作:

51101cda16323f7fa5ab60f52a6817b4.png

接着进行推导,我们整理一下:

07cdd485604b2d01d0cafb81e2e87dc7.png

这就证明上述关系是成立的,位置 的 query 和位置 的 key 的内积就是函数 。

把上面的式子用矩阵向量乘的形式来表达就是:

39fa43ed77a9047ea4aa3c14e06499f5.png

outside_default.png

参考文献

outside_default.png

[1] ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING https://arxiv.org/pdf/2104.09864.pdf

[2] 梁德澎:一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)https://zhuanlan.zhihu.com/p/642884818

[3] 马梦之:一步一步,推导旋转位置编码(Rotary Position Embedding, RoPE)https://zhuanlan.zhihu.com/p/644585013

[4] Transformer升级之路:博采众长的旋转式位置编码

更多阅读

c8dca41d249b6a9379ed3b1303536657.png

474c5ba8b2ceb4ac64a6519356948b92.png

0f84c808666d814bd870809b1cb546a4.png

5b6fe4d801f245637a8574b0a96a15ec.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/371421
推荐阅读
相关标签