当前位置:   article > 正文

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器

sgd优化器

1.简介

很多情况下,我们调用优化器的时候都不清楚里面的原理和构造,主要基于自己数据集和模型的特点,然后再根据别人的经验来选择或者尝试优化器。下面分别对SGD的原理、pytorch代码进行介绍和解析。


2.梯度下降

梯度下降方法可以分为3种,分别是:

  • BGD (Batch gradient descent)

这种方法是最朴素的梯度下降方法,将全部的数据样本输入网络计算梯度后进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k})

其中 w为模型参数, \bigtriangledown f(w^{k})为模型参数更新梯度,\alpha为学习率。

这个方法的最大问题就是容易落入局部最优点或者鞍点。

局部最优点很好理解,就是梯度在下降过程中遇到下图的情况,导致在local minimum区间不断震荡最终收敛。

鞍点(saddle point)是指一个非局部极值点的驻点,如下图所示,长得像一个马鞍因此得名。以红点的位置来说,在x轴方向是一个向上弯曲的曲线,在y轴方向是一个向下弯曲的曲线。当点从x轴方向向下滑动时,最终也会落入鞍点,导致梯度为0。

  • SGD (Stochastic gradient descent)

为了解决BGD落入鞍点或局部最优点的问题,SGD引入了随机性,即将每个数据样本输入网络计算梯度后就进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k};x^{_{i}};y^{_{i}})

其中 w为模型参数, \bigtriangledown f(w^{k};x^{_{i}};y^{_{i}})为一个样本输入后的模型参数更新梯度,\alpha为学习率。

由于要对每个样本都单独计算梯度,那么相当于引入了许多噪声,梯度下降时就会跳出鞍点和局部最优点。但要对每个样本都计算一次梯度就导致了时间复杂度较高,模型收敛较慢,而且loss和梯度会有大幅度的震荡。

  • MBGD (Mini-batch gradient descent)

MBGD相当于缝合了SGD和BGD,即将多个数据样本输入网络计算梯度后就进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})

其中 w为模型参数, \bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})为batch size个样本输入后的模型参数更新梯度,\alpha为学习率。

MBGD同时解决了两者的缺点,使得参数更新更稳定更快速,这也是我们最常用的方法,pytorch代码里SGD类也是指的MBGD(当然可以自己设置特殊的batch size,就会退化为SGD或BGD)。


3.SGD优化

实际在pytorch的代码中,还加了两个优化,分别是:

  • Momentum

从原理上可以很好理解,Momentum就是在当前step的参数更新中加入了部分上一个step的梯度,公式表示为:

v^{k} =\gamma *v^{k-1}-\alpha *\bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})

w^{^{k+1}} =w^{^{k}}-v^{^{k}}

其中 v^{^{k}}v^{^{k-1}}为当前step和上一个step的动量,即当前step的动量会有当前step的梯度和上一个step的动量叠加计算而来,其中\gamma一般设置为0.9或者0.99。

我们可以从以下两幅示意图中看到区别,第一张图没有加Momentum,第二张图加了Momentum。可以看到在第一张图中,点一开始往梯度变化的方向移动,但是到后来梯度逐渐变小,到最后变为了0,所以最终没有到达最优点。而第二张图由于加了Momentum,所以点会有一个横向移动的惯性,即使到了梯度为0的地方也能依靠惯性跳出。

  • Nesterov accelerated gradient(NAG)

加了Momentum之后,实际上模型参数更新的方向就不是当前点的梯度方向,所以这会一定程度上导致模型更新的不准确。NAG方法就是让参数先根据惯性预测出下一步点应该在的位置,然后根据预测点的梯度再更新一次:

w^{^{k{}'}} =w^{^{k}}-\gamma *v^{^{k-1}}

v^{k} =\gamma *v^{k-1}-\alpha *\bigtriangledown f(w^{k{}'};x^{_{i:i+b}};y^{_{i:i+b}})

w^{^{k+1}} =w^{^{k}}-v^{^{k}}


4.pytorch代码

以下代码为pytorch官方SGD代码,其中关键部分在step()中。

  1. import torch
  2. from torch.optim import Optimizer
  3. from torch.optim.optimizer import required
  4. class SGD(Optimizer):
  5. r"""Implements stochastic gradient descent (optionally with momentum).
  6. Nesterov momentum is based on the formula from
  7. `On the importance of initialization and momentum in deep learning`__.
  8. Args:
  9. params (iterable): iterable of parameters to optimize or dicts defining
  10. parameter groups
  11. lr (float): learning rate
  12. momentum (float, optional): momentum factor (default: 0)
  13. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  14. dampening (float, optional): dampening for momentum (default: 0)
  15. nesterov (bool, optional): enables Nesterov momentum (default: False)
  16. Example:
  17. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  18. >>> optimizer.zero_grad()
  19. >>> loss_fn(model(input), target).backward()
  20. >>> optimizer.step()
  21. __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
  22. .. note::
  23. The implementation of SGD with Momentum/Nesterov subtly differs from
  24. Sutskever et. al. and implementations in some other frameworks.
  25. Considering the specific case of Momentum, the update can be written as
  26. .. math::
  27. v = \rho * v + g \\
  28. p = p - lr * v
  29. where p, g, v and :math:`\rho` denote the parameters, gradient,
  30. velocity, and momentum respectively.
  31. This is in contrast to Sutskever et. al. and
  32. other frameworks which employ an update of the form
  33. .. math::
  34. v = \rho * v + lr * g \\
  35. p = p - v
  36. The Nesterov version is analogously modified.
  37. """
  38. def __init__(self, params, lr=required, momentum=0, dampening=0,
  39. weight_decay=0, nesterov=False):
  40. if lr is not required and lr < 0.0:
  41. raise ValueError("Invalid learning rate: {}".format(lr))
  42. if momentum < 0.0:
  43. raise ValueError("Invalid momentum value: {}".format(momentum))
  44. if weight_decay < 0.0:
  45. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  46. defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
  47. weight_decay=weight_decay, nesterov=nesterov)
  48. if nesterov and (momentum <= 0 or dampening != 0):
  49. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  50. super(SGD, self).__init__(params, defaults)
  51. def __setstate__(self, state):
  52. super(SGD, self).__setstate__(state)
  53. for group in self.param_groups:
  54. group.setdefault('nesterov', False)
  55. def step(self, closure=None):
  56. """Performs a single optimization step.
  57. Arguments:
  58. closure (callable, optional): A closure that reevaluates the model
  59. and returns the loss.
  60. """
  61. loss = None
  62. if closure is not None:
  63. loss = closure()
  64. for group in self.param_groups:
  65. weight_decay = group['weight_decay']
  66. momentum = group['momentum']
  67. dampening = group['dampening']
  68. nesterov = group['nesterov']
  69. for p in group['params']:
  70. if p.grad is None:
  71. continue
  72. d_p = p.grad.data
  73. if weight_decay != 0:
  74. d_p.add_(weight_decay, p.data)
  75. if momentum != 0:
  76. param_state = self.state[p]
  77. if 'momentum_buffer' not in param_state:
  78. buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
  79. else:
  80. buf = param_state['momentum_buffer']
  81. buf.mul_(momentum).add_(1 - dampening, d_p)
  82. if nesterov:
  83. d_p = d_p.add(momentum, buf)
  84. else:
  85. d_p = buf
  86. p.data.add_(-group['lr'], d_p)
  87. return loss

业务合作/学习交流+v:lizhiTechnology

如果想要了解更多优化器相关知识,可以参考我的专栏和其他相关文章:

优化器_Lcm_Tech的博客-CSDN博客

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【优化器】(二) AdaGrad原理 & pytorch代码解析_adagrad优化器-CSDN博客

【优化器】(三) RMSProp原理 & pytorch代码解析_rmsprop优化器-CSDN博客

【优化器】(四) AdaDelta原理 & pytorch代码解析_adadelta里rho越大越敏感-CSDN博客

【优化器】(五) Adam原理 & pytorch代码解析_adam优化器-CSDN博客

【优化器】(六) AdamW原理 & pytorch代码解析-CSDN博客

【优化器】(七) 优化器统一框架 & 总结分析_mosec优化器优点-CSDN博客

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/318970
推荐阅读
相关标签
  

闽ICP备14008679号