赞
踩
x = x - lr * grad_x
alpha
,可将之看作“decay_rate”,它相乘的 sigma
,保存了全部的历史的grad,即 g^t
。通过调整 alpha
的大小,可权衡新的lr是偏向新grad还是历史grad。eps
和 beta
,这里 beta
对应李宏毅教程就是alpha
,我们还是称之为 decay_rate
。v = decay_rate * v + (1 - decay_rate) * grad_x**2
x = x - lr * grad_x / (np.sqrt(v) + eps)
超参数: decay_rate
, eps
lambda
,每一步真实的梯度 v
,从原求导结果grad
变为了grad + lambda * v_t-1
。v = lambda * v + lr * grad
x = x - v
超参数:lambda
如图所示,是Adam优化器的伪代码。我们详细来看
m
是Momentum,指的是动量,即使用历史梯度平滑过的梯度; v
是RMSProp式中的sigma
(见李宏毅RMSProp部分的slide截图),即记录了全部历史grad,并用此进行梯度的指数加权平均。m
即动量momentum,并涉及第一个超参数 beta_1
。其实这里算法和momentum是几乎一样的,所以beta_1
这里的作用也是对历史梯度和当前梯度的衡量;sigma
计算步骤完全一样,略m
v
结果整合到一起,更新参数的梯度m = beta_1 * m + (1 - beta_1) * grad_x
v = beta_2 * v + (1 - beta_2) * grad_x**2
x = x - lr * m / (np.sqrt(v) + eps)
简单来说,AdamW就是Adam优化器加上L2正则,来限制参数值不可太大,这一点属于机器学习入门知识了。以往的L2正则是直接加在损失函数上,比如这样子:
L
o
s
s
=
L
o
s
s
+
1
2
λ
∑
θ
i
∈
Θ
θ
i
2
Loss = Loss + \frac{1}{2}\lambda\sum_{\theta_i \in \Theta}\theta_i ^2
Loss=Loss+21λθi∈Θ∑θi2
但AdamW稍有不同,如下图所示:
粉色部分,为传统L2正则施加的位置;而AdamW,则将正则加在了绿色位置。至于为何这么做?直接摘录BERT里面的原话看看——
# Just adding the square of the weights to the loss function is *not* # the correct way of using L2 regularization/weight decay with Adam, # since that will interact with the m and v parameters in strange ways. # # Instead we want to decay the weights in a manner that doesn't interact # with the m/v parameters. This is equivalent to adding the square # of the weights to the loss with plain (non-momentum) SGD. # Add weight decay at the end (fixed version)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
总之就是说,如果直接将L2正则加到loss上去,由于Adam优化器的后序操作,该正则项将会与m
和v
产生奇怪的作用。具体怎么交互的就不求甚解了,求导计算一遍应该即可得知。
因而,AdamW选择将L2正则项加在了Adam的m
和v
等参数被计算完之后、在与学习率lr相乘之前,所以这也表明了weight_decay和L2正则虽目的一致、公式一致,但用法还是不同,二者有着明显的差别。以BERT中的AdamW代码为例,具体是怎么做的一望便知:
step_size = group['lr']
if group['correct_bias']: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
if group['weight_decay'] > 0.0:
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
如code,注意BERT这里的step_size
就是当前的learning_rate。而最后两行就涉及weight_decy
的计算。如果我们将AdamW伪代码第12行的公式稍加化简,会发现实际上这一行大概是这样的:
θ
t
=
θ
t
−
1
−
l
r
∗
g
r
a
d
θ
−
l
r
∗
λ
∗
θ
t
−
1
\theta_t = \theta_{t-1} - lr * grad_\theta - lr * \lambda* \theta_{t-1}
θt=θt−1−lr∗gradθ−lr∗λ∗θt−1
此处
λ
\lambda
λ就是weight_decay
。再将上述公式最后一项和code最后一行对比,是不是一模一样呢。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。