赞
踩
在学习完GPT2之后,从本文开始进入Llama模型系列。
本文介绍Llama模型的改进之RMSNorm(均方根层归一化)。它是由Root Mean Square Layer Normalization论文提出来的,可以参阅其论文笔记1。
层归一化(LayerNorm)对Transformer等模型来说非常重要,它可以帮助稳定训练并提升模型收敛性。LayerNorm针对一个样本所有特征计算均值和方差,然后使用这些来对样本进行归一化:
μ
=
1
H
∑
i
=
1
H
x
i
,
σ
=
1
H
∑
i
=
1
H
(
x
i
−
μ
)
2
,
N
(
x
)
=
x
−
μ
σ
,
h
=
g
⊙
N
(
x
)
+
b
(1)
\mu = \frac{1}{H}\sum_{i=1}^H x_i,\quad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^H (x_i - \mu)^2}, \quad N(\pmb x) = \frac{\pmb x-\mu}{\sigma},\quad \pmb h = \pmb g \,\odot N(\pmb x) + \pmb b \tag 1
μ=H1i=1∑Hxi,σ=H1i=1∑H(xi−μ)2
,N(x)=σx−μ,h=g⊙N(x)+b(1)
这里
x
=
(
x
1
,
x
2
,
⋯
,
x
H
)
\pmb x = (x_1,x_2,\cdots, x_H)
x=(x1,x2,⋯,xH)表示某个时间步LN层的输入向量表示,向量维度为
H
H
H;
h
\pmb h
h实LN层的输出;
g
,
b
\pmb g,\pmb b
g,b实两个可学习的参数。
为什么层归一化有用?一些解释如下2:
虽然LayerNorm很好,但是它每次需要计算均值和方差。RMSNorm的思想就是移除(1)式中
μ
\mu
μ的计算部分1:
x
ˉ
i
=
x
i
RMS
(
x
)
g
i
RMS
(
x
)
=
1
H
∑
i
=
1
H
x
i
2
(2)
\bar x_i = \frac{x_i }{ \text{RMS}(\pmb x)} g_i \quad \text{RMS}(\pmb x) =\sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2} \tag 2
xˉi=RMS(x)xigiRMS(x)=H1i=1∑Hxi2
(2)
同时在实现也可以移除平移偏置 b \pmb b b。
单看(2)式的话,相当于仅使用 x \pmb x x的均方根来对输入进行归一化,它简化了层归一化的计算,变得更加高效,同时还有可能带来性能上的提升。
RMSNorm的实现很简单:
import torch import torch.nn as nn from torch import Tensor class RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(hidden_size)) def _norm(self, hidden_states: Tensor) -> Tensor: variance = hidden_states.pow(2).mean(-1, keepdim=True) return hidden_states * torch.rsqrt(variance + self.eps) def forward(self, hidden_states: Tensor) -> Tensor: return self.weight * self._norm(hidden_states.float()).type_as(hidden_states)
torch.rsqrt
是torch.sqrt
的倒数;eps
是一个很小的数,防止除零;hidden_states.float()
确保了标准差计算的精确度和稳定性,然后在forward
方法中,通过.type_as(hidden_states)
将结果转换回原来的数据类型,以保持与输入张量相同的数据类型,使得归一化处理后的结果与输入数据类型一致。
下面通过一个简单的网络来测试一下:
import torch import torch.nn as nn from torch import Tensor class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.linear = nn.Linear(in_features=10, out_features=5) self.rmsnorm = RMSNorm(hidden_size=5) def forward(self, x): x = self.linear(x) x = self.rmsnorm(x) return x net = SimpleNet() input_data = torch.randn(2, 10) # 2个样本,每个样本包含10个特征 output = net(input_data) print("Input Shape:", input_data.shape) print("Output Shape:", output.shape)
Input Shape: torch.Size([2, 10])
Output Shape: torch.Size([2, 5])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。