当前位置:   article > 正文

OneFlow的大模型分片保存和加载策略

模型分片

bdb45bd7fe56af54d919bddcfc6e2af8.jpeg

撰文 | 李响

1、大规模模型分片存储简介

在模型比较小时(如 100G 以下),还有可能采用单机存储。当模型参数量比较大时,要求的样本数也更大,训练后做 dump 出来的模型也会很大,单机肯定放不下。

比如,由 DeepSpeed 和 Megatron 驱动的 Megatron 图灵自然语言生成模型(MT-NLG)具有 5300 亿个参数,是迄今为止训练过的最大和最强大的单片 Transformer 语言模型,支持这样的大规模语言模型需要分片保存和加载,不会使用单机内存。此外,在其他 CV、搜索、推荐和广告类等场景下,读取样本量增多和模型复杂度增加都会带来模型存储上的难题。

本文将介绍 OneFlow 的大模型分片保存、加载策略以及使用方法。

2

OneFlow 模型分片保存和加载

OneFlow 的大模型分片保存和加载的实现基于全局视角(Global View,https://docs.oneflow.org/master/cookies/global_tensor.html)的概念,既利用 Placement 与 SBP 完成模型文件(下文都用 state dict 表示)在各个物理设备上的切分,适用于当模型大到无法在单个设备的内存或显存上容纳下的场景。

flow.utils.global_view.to_global() 接口介绍

为了更好理解下文保存模型和加载模型两个部分的内容,首先对 flow.utils.global_view.to_global() 接口和其实现思路进行分析。

区别于现有的 Tensor.to_global() 模式(可以处理普通的 Tensor,https://oneflow.readthedocs.io/en/master/generated/oneflow.Tensor.to_global.html?highlight=to_global%28%29),提供了多种类型的输入支持,包括 None、Tensor、List、Tuple、nn.Module 的 state dict 、nn.Graph 的 state dict 和几种类型的任意组合,既将 List/Tuple/Dict 中的输入 Tensor 转换为 Global Tensor。值得注意的是,其传入参数中的 SBP 支持用户自定义一个 (x, tensor) -> sbp 的函数来解决不同 Tensor 对应不同 SBP 的需求。

并且,与 to_global() 对应的还有 flow.utils.global_view.to_local() 接口。可以参考 API 文档中关于 to_global() 和 to_local() 更详细的介绍(https://oneflow.readthedocs.io/en/master/utils.global_view.html)。在 flow.utils.global_view.to_global() 的实现(https://github.com/Oneflow-Inc/oneflow/blob/master/python/oneflow/utils/global_view/to_global.py)中,支持了多种输入类型适用于现有的 Tensor.to_global() 接口。实现的整体思路大致为检查输入、广播(空)结构,遍历节点、调用回调函数和返回 to_global() 后的结果。

再回到我们关注的地方,这个接口如何做到模型分片保存和加载?

比如对于模型并行/流水并行,模型的参数分散在多个 Rank 上,在保存模型前通过 flow.utils.global_view.to_global() 将 state dict 里的每个 Tensor 在指定 Placement 上转为 Global Tensor,SBP 的类型为 flow.sbp.split,可以设置在特定维度上的切分。同样的,模型也可以按 Split 被加载。当然,SBP 也可以为 Broadcast,支持不同的 SBP 和 Placement 组合。这样,超大规模模型分片存储的问题就被非常好地解决了。

保存模型

大致了解 flow.utils.global_view.to_global() 接口后,在这一部分演示了如何分片保存模型,代码如下:

  1. # 自定义 get_sbp 函数。
  2. def get_sbp(state_dict, tensor):
  3. if tensor is state_dict["System-Train-TrainStep"]:
  4. return flow.sbp.broadcast
  5. if tensor is state_dict["module_pipeline"]["m_stage3.linear.weight"]:
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/816724
推荐阅读
相关标签
  

闽ICP备14008679号