当前位置:   article > 正文

PyTorch基于Apex的混合精度加速_apex pytorch对应关系

apex pytorch对应关系

 

安装:pip install apex

参考: https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/100135729

在这篇文章里,笔者会详解一下混合精度计算(Mixed Precision),并介绍一款 NVIDIA 开发的基于 PyTorch 的混合精度训练加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半。 


话不多说,直接先教你怎么用。


PyTorch实现
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()

对,就是这么简单,如果你不愿意花时间深入了解,读到这基本就可以直接使用起来了。


但是如果你希望对 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各种不明所以的“Nan”的同学,可以接着读下去,后面会有一些有趣的理论知识和笔者最近一个月使用 Apex 遇到的各种 bug,不过当你深入理解并解决掉这些 bug 后,你就可以彻底摆脱“慢吞吞”的 FP32 啦。

理论部分
为了充分理解混合精度的原理,以及 API 的

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/896628
推荐阅读
相关标签
  

闽ICP备14008679号