当前位置:   article > 正文

适配PyTorch FX,OneFlow让量化感知训练更简单_量化感知训练 bn pytorch

量化感知训练 bn pytorch

4593408479ded673cd6b55042f503b11.jpeg

作者 | 刘耀辉

审稿 | BBuf许啸宇

1

背景

近年来,量化感知训练是一个较为热点的问题,可以大大优化量化后训练造成精度损失的问题,使得训练过程更加高效。

Torch.fx在这一问题上走在了前列,使用纯Python语言实现了对于Torch.nn.Module的解析和向IR的转换,也可以提供变换后的IR对应的Python代码,在外部则是提供了简洁易用的API,大大方便了量化感知训练过程的搭建。此外,Torch.fx也有助于消除动态图和静态图之间的Gap,可以比较方便地对图进行操作以及进行算子融合。

OneFlow紧随其后添加了针对OneFlow的fx,即One-fx,在安装One-fx之后,用户可以直接调用oneflow.fx,也可以直接通过import onefx as fx进行使用。

one-fx地址:
https://github.com/Oneflow-Inc/one-fx

One-fx实现代码中绝大部分是对于Torch.fx的fork,但根据OneFlow和PyTorch之间存在的差别进行了一些适配或优化。本文将围绕One-fx适配方式以及在OneFlow中的应用展开。

2

FX主要模块

  • Symbolioc Trace

  • Graph Module

  • Interpreter

  • Proxy

  • Passes

其中,前4个模块共同实现了fx的基本功能,Graph Module和Proxy又是Symbolic Trace的基础,Passes则是在此基础上的扩充。

a10f10e09b076964f9a80420149b4d48.png

Symbolic Trace的基本概念如上图所示,最基本的模型运行过程就是从模型定义到模型执行这样一个流程。

fx则是进行了非侵入式的解析,将模型执行过程转成一张图,这张图中包含了很多个Node,每一个Node都包含了模型中的子模块或者函数调用信息,然后用户可以很方便地获取到所有的Node,并对其进行一些变换操作,最后通过GraphModule重新生成一个模型定义,并对其执行。

其中,在进行模型解析的时候,节点之间变量传递也均使用代理后的变量,如y = oneflow.relu(x),实际上x和y是Proxy(x)和Proxy(y)。

3

One-fx实现方式

这里给出一个Fx最简单的用例,以方便后续对于实现方式的介绍。

 
 
  1. import oneflow
  2. class MyModule(oneflow.nn.Module):
  3. def __init__(self):
  4. super().__init__()
  5. self.linear = oneflow.nn.Linear(512, 512)
  6. def forward(self, x):
  7. x = self.linear(x)
  8. y = oneflow.ones([2, 3])
  9. x = oneflow.relu(x)
  10. return y
  11. m = MyModule()
  12. traced = oneflow.fx.symbolic_trace(m)
  13. print(traced.code)
  14. """
  15. def forward(self, x):
  16. linear = self.linear(x); x = None
  17. relu = oneflow.relu(linear); linear = None
  18. _tensor_constant0 = self._tensor_constant0
  19. return _tensor_constant0
  20. """

函数代理

代理,即fx中的Proxy模块,目的是在每次进行函数或模块调用的时候添加一些额外操作,使得对模型的解析和重建得以进行,而包装则是适配代理的一种方式。

torch.fx中,对于nn.Module的包装比较易于理解,每当待解析Module中出现了继承自nn.Module的对象,那么就将其__call__函数替换成包装过的函数。然而,对于pytorch的函数的代理的实现要更“绕”一些,是借助了__torch_function__这一机制

https://github.com/pytorch/pytorch/blob/c7c723897658eda6298bb74d92e4bb18ab4a5fe3/torch/overrides.py),限于篇幅原因这里不专门对其进行介绍。比较关键的点是,OneFlow中没有这一机制,如果需要添加,那么会是规模很大的、侵入性的,于是One-fx的实现就需要找其它路径。

我们使用的解决方式是搜索oneflow,oneflow.nn.functional,oneflow._C等模块中的Callable,并去除其中属于类的部分,然后对其余函数进行包装,在每次解析模型之前,会将这些模块的__dict__中对应项替换成包装后的函数,并且在解析模型之后重新将这些项进行还原。对于constructor类型的函数,如ones,randn等则不进行代理,直接运行,在最终构建图的时候作为constant来处理。

对于函数的包装部分源码实现如下,每次运行代理后的函数,会先判断该函数的入参中有没有Proxy变量,如果有,那么将会创建一个call_function类型的节点并返回Proxy包装后的节点,否则直接调用原函数并返回结果。

  1. def _create_wrapped_func(orig_fn):
  2. @functools.wraps(orig_fn)
  3. def wrapped(*args, **kwargs):
  4. # 判断参数中是否存在proxy变量
  5. proxy = _find_proxy(args, kwargs)
  6. if proxy is not None:
  7. # 如果参数中有Proxy变量,创建节点并返回Proxy包装后的节点
  8. return_proxy = proxy.tracer.create_proxy(
  9. "call_function", orig_fn, args, kwargs
  10. )
  11. return_proxy.node.meta["is_wrapped"] = True
  12. return return_proxy
  13. # 如果没有Proxy变量,直接调用原函数
  14. return orig_fn(*args, **kwargs)
  15. return wrapped

其中,return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs)这行代码指定了使用与入参相同的Tracer来创建节点并返回结果,create_proxy函数定义的主要部分如下,创建节点并在Proxy包装后返回。

  1. def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
  2. name: Optional[str] = None, type_expr : Optional[Any] = None,
  3. proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
  4. args_ =
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/377543
推荐阅读
相关标签
  

闽ICP备14008679号