当前位置:   article > 正文

梯度爆炸解决方案——梯度截断(gradient clip norm)_clip gradient norm

clip gradient norm

如果梯度超过阈值,那么就截断,将梯度变为阈值

from torch.nn.utils import clip_grad_norm

pytorch源码

默认为l2(norm type)范数,对网络所有参数求l2范数,和最大梯度阈值相比,如果clip_coef<1,范数大于阈值,则所有梯度值乘以系数。

使用:

  1. optimizer.zero_grad()
  2. loss, hidden = model(data, hidden, targets)
  3. loss.backward()
  4. torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
  5. optimizer.step()

python - How to properly do gradient clipping in pytorch? - Stack Overflow  https://stackoverflow.com/questions/54716377/how-to-properly-do-gradient-clipping-in-pytorch

但是,clip_grad_norm还不够狠,有时候失效,这个时候更狠的就出来了:

torch.nn.utils.clip_grad_value_(model.parameters(), number)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/311533
推荐阅读
相关标签
  

闽ICP备14008679号