当前位置:   article > 正文

权重衰减 vs L2正则_adam中加入l2正则化怎么体现

adam中加入l2正则化怎么体现

避免过拟合的方法有很多:early stopping、数据集扩增(Data augmentation)、正则化(Regularization)包括L1、L2(L2 regularization也叫weight decay),dropout。

这里重点讲解 L2正则 和weight decay 的区别:如果有兴趣请查阅论文 Decoupled weight decay regularization

在训练神经网络的时候,由于Adam有着收敛快的特点被广泛使用。但是在很多数据集上的最好效果还是用SGD with Momentum细调出来的。可见Adam的泛化性并不如SGD with Momentum。在这篇文章中指出了Adam泛化性能差的一个重要原因就是Adam中L2正则项并不像在SGD中那么有效,并且通过Weight Decay的原始定义去修正了这个问题。文章表达了几个观点比较有意思。

一、L2正则和Weight Decay并不等价。这两者常常被大家混为一谈。首先两者的目的都是想是使得模型权重接近于0。L2正则是在损失函数的基础上增加L2 norm, 即为[公式] 。而权重衰减则是在梯度更新时直接增加一项, [公式] 。在标准SGD的情况下,通过对衰减系数做变换,可以将L2正则和Weight Decay看做一样。但是在Adam这种自适应学习率算法中两者并不等价。
二、使用Adam优化带L2正则的损失并不有效。如果引入L2正则项,在计算梯度的时候会加上对正则项求梯度的结果。那么如果本身比较大的一些权重对应的梯度也会比较大,由于Adam计算步骤中减去项会有除以梯度平方的累积,使得减去项偏小。按常理说,越大的权重应该惩罚越大,但是在Adam并不是这样。而权重衰减对所有的权重都是采用相同的系数进行更新,越大的权重显然惩罚越大。在常见的深度学习库中只提供了L2正则,并没有提供权重衰减的实现。这可能就是导致Adam跑出来的很多效果相对SGD with Momentum偏差的一个原因。
三、在Adam上,应该使用Weight decay,而不是L2正则。具体请见第二部分的图。

1. L2正则

1.1 L2 正则化与权重衰减

注意: L2 regularization 和 weight decay 只在vanilla SGD优化的情况下是等价的

L2正则化就是在代价函数后面加上一个正则项:
C = C 0 + λ 2 n ∑ w w 2 C = C_0 + \frac {\lambda} {2n} \sum_w w^2 C=C0+2nλww2
其中 C 0 C_0 C0代表原始的代价函数,后面的那一项是L2正则。L2正则是这样来的: 所有参数w的平方和,除以训练样本的大小n。 λ \lambda λ就是正则系数,权衡正则项与 C 0 C_0 C0项的比重。另外还有一个系数1/2, 1/2经常会看到主要是为了后面的求导的结果方便,后面那一项求导会产生一个2, 与1/2相乘刚好凑整为1。 系数 λ \lambda λ就是权重衰减系数。

1.2为什么L2正则可以对权重进行衰减

我们对加入的L2正则化后的代价函数进行推导,
——> 1 先求导:
∂ C ∂ w = ∂ C 0 ∂ w + λ n w \frac {\partial C}{\partial w} = \frac {\partial C_0}{\partial w} + \frac {\lambda}{n} w wC=wC0+nλw
∂ C ∂ b = ∂ C 0 ∂ b \frac {\partial C}{\partial b} = \frac {\partial C_0}{\partial b} bC=bC0

——> 2.分析L2正则对参数导数的影响:
可以看出 λ 2 n ∑ w w 2 \frac{\lambda}{2n}\sum_w w^2 2nλww2对 w的导数有影响,对b的导数没有影响。(其实因为L2正则只计算w 的范数,不加b,因为加上b后模型不好训练,并且也没必要)

——> 3.分析L2正则对参数更新的影响:
上面说了 L2正则项对b的导数没有影响,所以对b的参数更新也没有影响。那我们来看对w的更新是如何产生影响的:
w − − > w − η ∂ C 0 ∂ w − η λ n w w --> w -\eta \frac{\partial C_0}{\partial w} - \frac {\eta\lambda}{n}w w>wηwC0nηλw
= ( 1 − η λ n ) w − η ∂ C 0 ∂ w =(1-\frac{\eta \lambda}{n})w - \eta \frac{\partial C_0}{\partial w} =(1nηλ)wηwC0
在不使用L2正则的时候,求导结果中w前系数为1,现在w前面的系数 1 − η λ n 1-\frac{\eta \lambda}{n} 1nηλ ,因为 η λ n \eta \lambda n ηλn 都是正的,所以 1 − η λ n 1-\frac{\eta \lambda}{n} 1nηλ小于1,它的效果是减小w, 这也就是权重衰减(weight decay)的由来。当然考虑到后面的导数项,w最终的值可能增大也可能减小。
需要关注一下:对于基于mini-batch的随机梯度下降,w和b更新的公式跟上面给出的有点不同:
w − − > ( 1 − η λ n ) w − η m ∑ x ∂ C x ∂ w w --> (1-\frac {\eta \lambda}{n})w - \frac {\eta}{m}\sum_x \frac{\partial C_x}{\partial w} w>(1nηλ)wmηxwCx
b − − > b − η m ∑ x ∂ C x ∂ b b --> b - \frac {\eta}{m}\sum_x \frac{\partial C_x}{\partial b} b>bmηxbCx
对比上面w的更新公式,可以发现后面的那一项变了,变成所有导数加和,乘以 η \eta η 再除以m, m是一个min-batch中样本个数。

1.3 权重衰减的作用(L2正则化)

作用: 可以避免模型过拟合的问题
思考:L2正则项 有让w变小的效果,但是为什么w变小 可以防止过拟合?
原理:
1)从模型复杂度上解释,更小的权值w,从某种意义上说,表示网络的复杂度更低,对数据的拟合更好(这个法则叫做奥卡姆剃刀),而实际应用中,也验证了这一点,L2正则化的效果往往好于未经正则化的效果。
2)从数学方面解释:过拟合的时候,拟合函数的系数往往非常大,为什么?如下图所示,过拟合就是拟合函数需要顾及每一个点,最终形成的拟合函数波动很大。在某些很小的区间里面函数值得变化很剧烈。这就意味着函数在某些小区间里的导数值(绝对值)非常大,由于自变量值可大可小,所以只有系数足够大,才能保证导数值很大。而正则化是通过约束参数的范数使其不要太大,所以可以在一定程度上减少过拟合的情况。
在这里插入图片描述

2. Weight decay VS L2 正则 的区别

注意: L2 regularization 和 weight decay 只在vanilla SGD优化的情况下是等价的

Adam + L2 regularization 并不能实现标准的weight decay

Adam 自动调整学习率,大幅提高了训练速度,也很少需要调整学习率,但是很多资料报告Adam优化的最终精度略低于SGD。问题出在哪里?其实Adam本身没有问题,问题在于大多数DL框架的L2 regularization实现用的是weight decay的方式,而weight decay在于Adam共同使用时 有互相耦合。

weight decay

Adam generally requires more regularization than SGD, so be sure to adjust your regularization hyper-parameters when switching from SGD to Adam.
一般情况下,Adam比SGD需要更大的正则,所以当优化器从SGD转成Adam时 确保调整正则化参数。

第一部分已经讲了 在loss函数里面增加L2正则,对应w更新时 相当于 w衰减 之后再减去梯度(其实这个只是对于 vanilla SGD 时才是对等的)。

我们借用fasta.ai里面的公式:
loss函数里面的添加L2正则: (其中wd 是 weight decay 系数)

final_loss = loss + wd * all_weights.pow(2).sum() / 2

当使用vanilla SGD时,loss函数的 L2正则,等价于 w权重更新:

w = w - lr * w.grad - lr * wd * w

为什么在 这里 要把loss 里面的 L2正则 和 weight decay 分开来讲呢?因为上面提到了 这两个只有在vanilla SGD时 才是对等的。其他的情况都不对等。引用原文的话:

The answer is that they are only the same thing for vanilla SGD, but as soon as we add momentum, or use a more sophisticated optimizer like Adam, L2 regularization (first equation) and weight decay (second equation) become different.

为什么不同? 我们来拿带动量的SGD 来 举个例子

带动量的SGD 上 L2正则 vs weight decay

——>1. 在loss函数里面添加L2正则的情况
根据第一部分的推导,我们可以知道,添加L2 正则就是在w的梯度里面添加 w d ∗ w wd * w wdw, 并不是直接添加到w更新里面减去 l r ∗ w d ∗ w lr * wd * w lrwdw

首先,我们要计算 moving average: (在w.grad 里面添加 w d ∗ w wd*w wdw

moving_avg = alpha * moving_avg + (1-alpha) * (w.grad + wd*w)
w = w - lr * moving_avg

在w 更新时,w - lr * moving_avg 其中 w减去的 与正则有关的有2部分:1. 本次更新moving_avg时的 l r ∗ ( 1 − a l p h a ) ∗ w d ∗ w lr * (1-alpha)*wd*w lr(1alpha)wdw 2. 之前更新 moving_avg里面的 decay.

——>2. 在w更新里面添加weight decay的情况
带动量的SGD的 单独的 weight decay 更新如下, **只在w更新时添加一个 − l r ∗ w d ∗ w -lr* wd * w lrwdw **.

moving_avg = alpha * moving_avg + (1-alpha) * w.grad
w = w - lr * moving_avg - lr * wd * w

结论

从上面可以看到 这里的 w 更新(weight decay)和L2正则 在带动量的SGD上 与 vanilla SGD 上是不一样的。
这种不同,在Adam 里面更严重: L2 正则的情况下,我们在w.grad上添加 w d ∗ w wd*w wdw 来计算moving_avg,然后再更新w.然而,weight decay 只是简单在更新的时候从weight 里面减去 w d ∗ w wd * w wdw。经过试验验证,在Adam上我们应该使用weight decay 而不是 L2正则。**
在这里插入图片描述

we should use weight decay with Adam, and not the L2 regularization that classic deep learning libraries implement

AdamW 上 weight decay

Adam上 应该使用 weight decay,而不是L2正则。那么如在Adam上正确使用weight decay?

下图中的绿色部分就是在Adam中正确引入Weight Decay的方式,称作AdamW
在这里插入图片描述

为什么在 Adam 上 weight decay 比L2正则好?

下图是作者分析为什么Adam + Weightdecay 比 Adam+L2 regularization 效果更好的示意图:

第一排是SGD,第二排是Adam,第一列是 SGD/Adam + L2 正则, 第二列是 SGD/Adam 加上 weight decay. 可以看到 在Adam上L2 正则的结果要比SGD要差,weightdecay 的优化的结果更好。
在这里插入图片描述

这里是对比Adam和AdamW的优化下 training loss 和 test error.

  • 根据第一排2个图片可以看到AdamW training loss 下降的比Adam快,并且 test error 减少的也比Adam快,说明 AdamW收敛更快。
  • 第二排对比的是在Adam 里面添加 weight decay(其实是根据L2正则来的,有耦合)和 AdamW 里面的weight decay的 test error 比对,AdamW 性能比Adam好。最后一个trainning loss 也是 AdamW更好。
    在这里插入图片描述
在AdamW 优化时 调整学习率会有更好的结果

第一排是 Adam +L2 regularization 第二排是 Adam + weight decay
三列对比是为了证明 调整学习率 对 AdamW 优化有提高,第一列是固定学习率,第二列是按epoch[30,60,80]降低学习率,第三列是按照余弦退火的方式调整学习率。 最好的是就是 第二行第三列即第6个图 AdamW + ConsineAnnealing 通过余弦退火的方式调整学习率,得到的 低error 空间更大。
在这里插入图片描述

copy 论文的总结

在这里插入图片描述
在这里插入图片描述

AdamW in pytorch

pytorch的源码是个很好的学习资源,这里面提到了3篇论文分别是A method for stochastic optimization,Decoupled weight decay Regularization, On the convergence of Adam and Beyond.

import math
import torch
from .optimizer import Optimizer


[docs]class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

[docs]    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118

参考:
正则化方法:防止过拟合,提高泛化能力
weight decay VS L2 正则
fasta.ai

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/349782
推荐阅读
相关标签
  

闽ICP备14008679号