I've been reading the Gemma code recently and find it a bit easier to follow than the LLaMA code. Its RoPE code differs from the usual approach (or at least from the way I had understood it), so I'm writing this down. My introductory reference for RoPE: Rotary Position Embedding (RoPE, 旋转式位置编码) | 原理讲解+torch代码实现
I won't go over the theory itself here; the figure that was originally pasted at this point comes from the link above.
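Since that figure doesn't survive the copy, here is the identity it illustrated, stated informally (d is the per-head dimension, m the token position): each two-dimensional pair of features $(u, v)$ at position $m$ is treated as a complex number and rotated,

$(u + \mathrm{i}\,v)\,e^{\mathrm{i}\,m\,\theta_k}, \qquad \theta_k = 10000^{-2k/d}, \quad k = 0, \dots, d/2 - 1.$

Which two features form a pair is exactly where the Gemma and LLaMA implementations below differ.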
Let's paste the code first and go through it step by step:
import torch

# dim: per-head dimension   end: sequence length   theta: 10000 (the base)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

# x: input of shape [batch, end, num_head, dim]   freqs_cis: output of the previous function
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out
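Before stepping through it, here is a minimal usage sketch. The concrete sizes (batch=2, end=16, num_head=4, dim=8) are made-up values for illustration; the shape convention follows the comments above.

batch, end, num_head, dim = 2, 16, 4, 8
freqs_cis = precompute_freqs_cis(dim, end)   # complex64, shape [end, dim // 2] = [16, 4]
x = torch.randn(batch, end, num_head, dim)   # query (or key) activations
out = apply_rotary_emb(x, freqs_cis)         # rotated output, same shape as x: [2, 16, 4, 8]
print(freqs_cis.shape, out.shape)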
The part that tripped me up is apply_rotary_emb: Gemma splits the last dimension into a front half and a back half with torch.chunk, then stacks them along a new trailing dimension, so dimension i is paired with dimension i + dim/2. A quick experiment makes the pairing clear:

import torch
a = torch.arange(10)
print(a)   # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = torch.chunk(a, 2, dim=-1)
print(b)   # (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
c = torch.stack(b, dim=-1)
print(c)   # tensor([[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]])
These stacked pairs are then converted to complex form with torch.view_as_complex.
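Continuing the toy tensors above (a minimal sketch reusing c from the experiment; the cast to float is needed because view_as_complex expects a floating-point input):

d = torch.view_as_complex(c.float())
print(d)   # tensor([0.+5.j, 1.+6.j, 2.+7.j, 3.+8.j, 4.+9.j])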
Below is LLaMA's RoPE implementation:
from typing import Tuple

def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
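The snippet calls reshape_for_broadcast, which is not included in this excerpt. As a rough sketch of what it needs to do (my own reconstruction, not necessarily the exact upstream code): it inserts singleton dimensions so that freqs_cis of shape [seq_len, head_dim // 2] broadcasts against the complex-valued xq_ of shape [batch, seq_len, n_heads, head_dim // 2].

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Keep the seq_len dim and the last (head_dim // 2) dim, set all others to 1,
    # so the element-wise complex multiply broadcasts over batch and heads.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)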
As I understand it, LLaMA instead pairs the dimensions in an even/odd fashion, i.e. adjacent dimensions form a pair. A quick experiment:
import torch
a = torch.arange(10)
print(a)   # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b = a.reshape(5, 2)
print(b)   # tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
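To make the contrast explicit, here is a minimal sketch reusing b from above:

# LLaMA-style pairing: adjacent dims (2i, 2i+1) become one complex number,
# unlike Gemma's (i, i + dim/2) pairing shown earlier.
print(torch.view_as_complex(b.float()))
# tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j, 8.+9.j])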
That wraps up my understanding of these RoPE implementations. Compared with the way I originally understood it, this style is more concise, but also a bit more roundabout.