赞
踩
本文我们将主要介绍PyTorch中自带的torch.onnx模块。该模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)
参数:
model.state_dict().values()
指定。有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.
1)pytorch模型保存
- import torch
- torch.save(selfmodel,"save.pt")
2)pytorch模型加载
- import torch
- torch.load("save.pt")
1)pytorch模型保存
- import torch
- torch.save(selfmodel.state_dict(),"save.pt")
2)pytorch模型加载
selfmodel.load_state_dict(torch.load("save.pt"))
- import torch
- torch_model = torch.load("save.pt") # pytorch模型加载
- batch_size = 1 #批处理大小
- input_shape = (3,244,244) #输入数据
-
- # set the model to inference mode
- torch_model.eval()
-
- x = torch.randn(batch_size,*input_shape) # 生成张量
- export_onnx_file = "test.onnx" # 目的ONNX文件名
- torch.onnx.export(torch_model,
- x,
- export_onnx_file,
- opset_version=10,
- do_constant_folding=True, # 是否执行常量折叠优化
- input_names=["input"], # 输入名
- output_names=["output"], # 输出名
- dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
- "output":{0:"batch_size"}})
注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.
- import torch
- torch_model = selfmodel() # 由研究员提供python.py文件
- batch_size = 1 # 批处理大小
- input_shape = (3, 244, 244) # 输入数据
-
- # set the model to inference mode
- torch_model.eval()
-
- x = torch.randn(batch_size,*input_shape) # 生成张量
- export_onnx_file = "test.onnx" # 目的ONNX文件名
- torch.onnx.export(torch_model,
- x,
- export_onnx_file,
- opset_version=10,
- do_constant_folding=True, # 是否执行常量折叠优化
- input_names=["input"], # 输入名
- output_names=["output"], # 输出名
- dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
- "output":{0:"batch_size"}})
-
链接:
1)pytorch官方文档.
2)参考
- pth->onnx常见问题
-
-
- ##模型输入输出不支持字典
- 在使用torch.onnx导出onnx格式时,模型的输入和输出都不支持字典类型的数据结构。
-
-
- **解决方法:**
- 此时,可以将字典数据结构换为torch.onnx支持的列表或者元组。例如:
- heads {'hm': 1, 'wh': 2, 'hps': 34, 'reg': 2, 'hm_hp': 17, 'hp_offset': 2}
- Centerpose中的字典在导出onnx时将会报错,可以将字典拆为两个列表,然后模型代码中做相应修改。
-
-
- ##tensor.size(dim)操作导致的转出的onnx操作中带有Gather[0]
- 如果在网络中存在类似tensor.size(dim)或者tensor.shape[1]这种索引操作,就会在转出的onnx模型中产生出Gather[0]操作,而这个操作,在TensorRT中是不支持的。
-
-
- **解决方法:**
- 尽量避免类似的操作,代码中不要出现与网络操作无关的tensor,一些尺寸之类的常数全部写成普通变量。
-
-
- ##求余算符在导出onnx时不支持
- 曾经在导出求余算符‘%’时出现如下报错,是因为此运算符为ATen操作符,但是未在torch/onnx/symbolic_opset9.py中定义。
- /usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:562: UserWarning: ONNX export failed on ATen operator remainder because torch.onnx.symbolic_opset9.remainder does not exist
- .format(op_name, opset_version, op_name))
- Traceback (most recent call last):
- File "transfer_centernet_dlav0_to_onnx2.py", line 65, in <module>
- onnx_module = torch.onnx.export(model,img,'multipose_dlav0_1x_modellast_process.onnx',verbose=True)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 132, in export
- strip_doc_string, dynamic_axes)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 64, in export
- example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 329, in _export
- _retain_param_name, do_constant_folding)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 225, in _model_to_graph
- _disable_torch_constant_prop=_disable_torch_constant_prop)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 127, in _optimize_graph
- graph = torch._C._jit_pass_onnx(graph, operator_export_type)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 163, in _run_symbolic_function
- return utils._run_symbolic_function(*args, **kwargs)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 563, in _run_symbolic_function
- op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
- File "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
- return _registry[(domain, version)][opname]
- KeyError: 'remainder'
-
-
- **解决方法:**
- 在torch/onnx/symbolic_opset9.py中补上自己的remainder的定义,添加的代码如下
- @parse_args( 'v', 'v')
- def remainder(g,input,division):
- return g.op("Remainder",input,division)
- "', since it's not constant, please try to make "
- RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible
-
- 解决方案:可能是pytorch版本不对,当前是pytorch1.5版本,但是换了好多版本还是不行,这个问题有待解决
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。