赞
踩
旋转位置编码
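旋转位置编码（RoPE）的核心是为每个位置找到对应的旋转矩阵：把 $d$ 维词向量的元素两两分组，位置 $m$ 处的第 $i$ 组按角度 $m\theta_i$ 旋转，其中 $\theta_i = 10000^{-2i/d}$。在复数域中，这一旋转等价于把该组元素视为一个复数并乘以 $e^{\mathrm{i}\,m\theta_i}$。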
LLaMA 中与旋转矩阵相关的代码（以下为简化的单头注意力版本）：
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the complex rotation factors used by rotary position embedding.

    The ``dim`` features of a vector are grouped into ``dim // 2`` consecutive
    pairs. Pair ``i`` at position ``m`` is rotated by the angle ``m * w_i``,
    where ``w_i = theta ** (-2i / dim)``.

    Args:
        dim: Feature dimension of the vectors that will be rotated.
        seq_len: Number of positions to precompute (positions ``0 .. seq_len - 1``).
        theta: Base of the geometric frequency schedule.

    Returns:
        Complex tensor of shape ``[seq_len, dim // 2]`` whose entry ``[m, i]`` is
        ``cos(m * w_i) + i * sin(m * w_i)`` (unit magnitude).
    """
    # Rotation frequency of each feature pair: w_i = theta ** (-2i / dim).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Token positions [0, 1, ..., seq_len - 1]. They are built as float32 so the
    # outer product below is computed in float32 without relying on dtype promotion.
    t = torch.arange(seq_len, device=freqs.device, dtype=torch.float32)
    # angles[m, i] = m * w_i, shape [seq_len, dim // 2].
    angles = torch.outer(t, freqs).float()
    # torch.polar(abs, angle) = abs * (cos(angle) + i * sin(angle)).
    # With abs = 1 this is e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key vectors by precomputed position factors.

    Consecutive feature pairs ``(x[2i], x[2i+1])`` are viewed as complex numbers
    ``x[2i] + i * x[2i+1]`` and multiplied by ``freqs_cis``. This rotates each
    pair by its position-dependent angle.

    Args:
        xq: Query tensor whose (even-sized) last dimension holds the features,
            e.g. ``[batch_size, seq_len, dim]``.
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: Complex factors that must broadcast against ``xq``/``xk`` once
            their last dimension is halved, e.g. ``[seq_len, dim // 2]`` for
            ``[batch_size, seq_len, dim]`` inputs.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # [..., dim] -> [..., dim // 2, 2] -> complex [..., dim // 2].
    # The cast to float32 is needed because view_as_complex requires a real
    # floating tensor whose last dimension has size 2.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Complex multiplication performs the rotation. view_as_real returns
    # [..., dim // 2, 2]; merging the last two dims restores [..., dim] for
    # inputs of any rank.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    """Single-head self-attention with rotary position embedding (simplified).

    Args:
        args: Model configuration. Assumed to provide ``dim`` (model width) and
            ``max_seq_len`` (longest supported sequence), as LLaMA's
            ``ModelArgs`` does. TODO confirm against the actual ``ModelArgs``.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.dim = args.dim
        # Bias-free projections. This is single-head attention, so every
        # projection maps dim -> dim.
        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        # Rotation factors for up to 2 * max_seq_len positions,
        # shape [2 * max_seq_len, dim // 2]. This is a plain attribute, not a
        # buffer, so forward() moves the needed slice to the input's device.
        self.freqs_cis = precompute_freqs_cis(args.dim, args.max_seq_len * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply rotary self-attention to ``x`` of shape ``[batch_size, seq_len, dim]``.

        No attention mask is applied: every position attends to every position.

        Returns:
            Tensor of shape ``[batch_size, seq_len, dim]``.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.dim)
        xk = xk.view(bsz, seqlen, self.dim)
        xv = xv.view(bsz, seqlen, self.dim)

        # Use only the factors for positions 0 .. seqlen - 1, on the input's device.
        freqs_cis = self.freqs_cis[:seqlen].to(x.device)
        # Rotate queries and keys before computing attention scores.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # scores.shape = [batch_size, seq_len, seq_len]
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(self.dim)
        # Softmax runs in float32 for stability. The result is cast back so the
        # matmul with xv does not mix dtypes (e.g. under half precision).
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # [batch_size, seq_len, dim]
        return output
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。