当前位置:   article > 正文

TorchScript的TracedModule和ScriptModule的区别_recursivescriptmodule

recursivescriptmodule

最近在读TorchScript的入门介绍,看了官方链接的文章,然后感觉是云山雾罩,不知所云。

然后搜索到了Rene Wang的文章,才感觉明白了好多。

官方的介绍TracedModule的缺点例子是这样的:

  1. class MyDecisionGate(torch.nn.Module):
  2. def forward(self, x):
  3. if x.sum() > 0:
  4. return x
  5. else:
  6. return -x
  7. class MyCell(torch.nn.Module):
  8. def __init__(self, dg):
  9. super(MyCell, self).__init__()
  10. self.dg = dg
  11. self.linear = torch.nn.Linear(4, 4)
  12. def forward(self, x, h):
  13. new_h = torch.tanh(self.dg(self.linear(x)) + h)
  14. return new_h, new_h
  15. my_cell = MyCell(MyDecisionGate())
  16. traced_cell = torch.jit.trace(my_cell, (x, h))
  17. print(traced_cell.code)

输出是:

  1. def forward(self,
  2. input: Tensor,
  3. h: Tensor) -> Tuple[Tensor, Tensor]:
  4. _0 = (self.dg).forward((self.linear).forward(input, ), )
  5. _1 = torch.tanh(torch.add(_0, h, alpha=1))
  6. return (_1, _1)

然后官方再介绍ScriptMoudle:

  1. scripted_gate = torch.jit.script(MyDecisionGate())
  2. my_cell = MyCell(scripted_gate)
  3. traced_cell = torch.jit.script(my_cell)
  4. print(traced_cell.code)

然后输出是:

  1. def forward(self,
  2. x: Tensor,
  3. h: Tensor) -> Tuple[Tensor, Tensor]:
  4. _0 = (self.dg).forward((self.linear).forward(x, ), )
  5. new_h = torch.tanh(torch.add(_0, h, alpha=1))
  6. return (new_h, new_h)

然后文章里就高潮叫hooray了,我还是一脸懵逼的,根本没有看到ScriptModule的code与TracedModule的code差异啊?

Rene Wang的文章解释的很到位,关键要看my_cell.dg.code,其实他们是这样的

  1. traced_gate = torch.jit.trace(my_cell.dg, (x,))
  2. print(traced_gate.code)
  3. --输出--
  4. c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  5. after removing the cwd from sys.path.
  6. def forward(self,
  7. x: Tensor) -> Tensor:
  8. return x
  1. scripted_gate = torch.jit.script(MyDecisionGate())
  2. print(scripted_gate.code)
  3. my_cell = MyCell(scripted_gate)
  4. traced_cell = torch.jit.script(my_cell)
  5. print(traced_cell)
  6. print(traced_cell.code)
  7. #只有从dg.code才能看到 if else 流程控制语句执行了
  8. print(traced_cell.dg.code)
  9. --输出--
  10. def forward(self,
  11. x: Tensor) -> Tensor:
  12. _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  13. if _0:
  14. _1 = x
  15. else:
  16. _1 = torch.neg(x)
  17. return _1
  18. RecursiveScriptModule(
  19. original_name=MyCell
  20. (dg): RecursiveScriptModule(original_name=MyDecisionGate)
  21. (linear): RecursiveScriptModule(original_name=Linear)
  22. )
  23. def forward(self,
  24. x: Tensor,
  25. h: Tensor) -> Tuple[Tensor, Tensor]:
  26. _0 = (self.dg).forward((self.linear).forward(x, ), )
  27. new_h = torch.tanh(torch.add(_0, h, alpha=1))
  28. return (new_h, new_h)
  29. def forward(self,
  30. x: Tensor) -> Tensor:
  31. _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  32. if _0:
  33. _1 = x
  34. else:
  35. _1 = torch.neg(x)
  36. return _1

这样能够清晰的看到ScriptModule追踪到了if else 控制流。

基于torch 1.4.0版本,可能官方的tutorial是基于老的版本的实例。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/850140
推荐阅读
相关标签
  

闽ICP备14008679号