赞
踩
最近在读TorchScript的入门介绍,看了官方链接的文章,然后感觉是云山雾罩,不知所云。
然后搜索到了Rene Wang的文章,才感觉明白了好多。
官方的介绍TracedModule的缺点例子是这样的:
- class MyDecisionGate(torch.nn.Module):
- def forward(self, x):
- if x.sum() > 0:
- return x
- else:
- return -x
-
- class MyCell(torch.nn.Module):
- def __init__(self, dg):
- super(MyCell, self).__init__()
- self.dg = dg
- self.linear = torch.nn.Linear(4, 4)
-
- def forward(self, x, h):
- new_h = torch.tanh(self.dg(self.linear(x)) + h)
- return new_h, new_h
-
- my_cell = MyCell(MyDecisionGate())
- traced_cell = torch.jit.trace(my_cell, (x, h))
- print(traced_cell.code)
输出是:
- def forward(self,
- input: Tensor,
- h: Tensor) -> Tuple[Tensor, Tensor]:
- _0 = (self.dg).forward((self.linear).forward(input, ), )
- _1 = torch.tanh(torch.add(_0, h, alpha=1))
- return (_1, _1)
然后官方再介绍ScriptMoudle:
- scripted_gate = torch.jit.script(MyDecisionGate())
-
- my_cell = MyCell(scripted_gate)
- traced_cell = torch.jit.script(my_cell)
- print(traced_cell.code)
然后输出是:
- def forward(self,
- x: Tensor,
- h: Tensor) -> Tuple[Tensor, Tensor]:
- _0 = (self.dg).forward((self.linear).forward(x, ), )
- new_h = torch.tanh(torch.add(_0, h, alpha=1))
- return (new_h, new_h)
然后文章里就高潮叫hooray了,我还是一脸懵逼的,根本没有看到ScriptModule的code与TracedModule的code差异啊?
Rene Wang的文章解释的很到位,关键要看my_cell.dg.code,其实他们是这样的
- traced_gate = torch.jit.trace(my_cell.dg, (x,))
- print(traced_gate.code)
-
- --输出--
- c:\python36\lib\site-packages\ipykernel_launcher.py:4: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
- after removing the cwd from sys.path.
- def forward(self,
- x: Tensor) -> Tensor:
- return x
- scripted_gate = torch.jit.script(MyDecisionGate())
- print(scripted_gate.code)
- my_cell = MyCell(scripted_gate)
- traced_cell = torch.jit.script(my_cell)
- print(traced_cell)
- print(traced_cell.code)
- #只有从dg.code才能看到 if else 流程控制语句执行了
- print(traced_cell.dg.code)
-
- --输出--
-
- def forward(self,
- x: Tensor) -> Tensor:
- _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
- if _0:
- _1 = x
- else:
- _1 = torch.neg(x)
- return _1
-
- RecursiveScriptModule(
- original_name=MyCell
- (dg): RecursiveScriptModule(original_name=MyDecisionGate)
- (linear): RecursiveScriptModule(original_name=Linear)
- )
- def forward(self,
- x: Tensor,
- h: Tensor) -> Tuple[Tensor, Tensor]:
- _0 = (self.dg).forward((self.linear).forward(x, ), )
- new_h = torch.tanh(torch.add(_0, h, alpha=1))
- return (new_h, new_h)
-
- def forward(self,
- x: Tensor) -> Tensor:
- _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
- if _0:
- _1 = x
- else:
- _1 = torch.neg(x)
- return _1
这样能够清晰的看到ScriptModule追踪到了if else 控制流。
基于torch 1.4.0版本,可能官方的tutorial是基于老的版本的实例。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。