当前位置:   article > 正文

PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

from apex import amp

640

作者丨Nicolas

单位丨追一科技AI Lab研究员

研究方向丨信息抽取、机器阅读理解

你想获得双倍训练速度的快感吗? 

你想让你的显卡内存瞬间翻倍吗? 

如果告诉你只需要三行代码即可实现,你信不? 

在这篇文章里,笔者会详解一下混合精度计算(Mixed Precision),并介绍一款 NVIDIA 开发的基于 PyTorch 的混合精度训练加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半。 

话不多说,直接先教你怎么用。

PyTorch实现

  1. from apex import amp
  2. model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
  3. with amp.scale_loss(loss, optimizer) as scaled_loss:
  4. scaled_loss.backward()

对,就是这么简单,如果你不愿意花时间深入了解,读到这基本就可以直接使用起来了。

但是如果你希望对 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各种不明所以的“Nan”的同学,可以接着读下去,后面会有一些有趣的理论知识和笔者最近一个月使用 Apex 遇到的各种 bug,不过当你深入理解并解决掉这些 bug 后,你就可以彻底摆脱“慢吞吞”的 FP32 啦。

理论部分

为了充分理解混合精度的原理,以及 API 的使用,先补充一点基础的理论知识。

1. 什么是FP16?

半精度浮点数是一种计算机使用的二进制浮点数数据类型,使用 2 字节(16 位)存储。

640?wx_fmt=png

▲ FP16和FP32表示的范围和精度对比

 

其中, sign 位表示正负, exponent 位表示指数640?wx_fmt=png, fraction 位表示的是分数640?wx_fmt=png。其中当指数为零的时候,下图加号左边为 0,其他情况为 1。

640?wx_fmt=png

▲ FP16的表示范例

 

2. 为什么需要FP16?

在使用 FP16 之前,我想再赘述一下为什么我们使用 FP16。

  • 减少显存占用 现在模型越来越大,当你使用 Bert 这一类的预训练模型时,往往模型及模型计算就占去显存的大半,当想要使用更大的 Batch Size 的时候会显得捉襟见肘。由于 FP16 的内存占用只有 FP32 的一半,自然地就可以帮助训练过程节省一半的显存空间。

  • 加快训练和推断的计算 与普通的空间时间 Trade-off 的加速方法不同&#x

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/896615
推荐阅读
相关标签
  

闽ICP备14008679号