当前位置:   article > 正文

pytorch模型存储—转化为ONNX_torch模型导出onnx

torch模型导出onnx

本文我们将主要介绍PyTorch中自带的torch.onnx模块。该模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)

参数:

  • model(torch.nn.Module)-要被导出的模型
  • args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
  • f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
  • export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
  • verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
  • training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
  • input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
  • output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。

本篇幅介绍pytorch模型转ONNX模型

一、pytorch模型保存/加载

有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.

1、文件中保存模型结构和权重参数

1)pytorch模型保存

  1. import torch
  2. torch.save(selfmodel,"save.pt")

2)pytorch模型加载

  1. import torch
  2. torch.load("save.pt")

2、文件只保留模型权重

1)pytorch模型保存

  1. import torch
  2. torch.save(selfmodel.state_dict(),"save.pt")

2)pytorch模型加载

selfmodel.load_state_dict(torch.load("save.pt"))

二、pytorch模型转ONNX模型

1、文件中保存模型结构和权重参数

  1. import torch
  2. torch_model = torch.load("save.pt") # pytorch模型加载
  3. batch_size = 1 #批处理大小
  4. input_shape = (3,244,244) #输入数据
  5. # set the model to inference mode
  6. torch_model.eval()
  7. x = torch.randn(batch_size,*input_shape) # 生成张量
  8. export_onnx_file = "test.onnx" # 目的ONNX文件名
  9. torch.onnx.export(torch_model,
  10. x,
  11. export_onnx_file,
  12. opset_version=10,
  13. do_constant_folding=True, # 是否执行常量折叠优化
  14. input_names=["input"], # 输入名
  15. output_names=["output"], # 输出名
  16. dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
  17. "output":{0:"batch_size"}})

注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

2、文件中只保留模型权重

  1. import torch
  2. torch_model = selfmodel() # 由研究员提供python.py文件
  3. batch_size = 1 # 批处理大小
  4. input_shape = (3, 244, 244) # 输入数据
  5. # set the model to inference mode
  6. torch_model.eval()
  7. x = torch.randn(batch_size,*input_shape) # 生成张量
  8. export_onnx_file = "test.onnx" # 目的ONNX文件名
  9. torch.onnx.export(torch_model,
  10. x,
  11. export_onnx_file,
  12. opset_version=10,
  13. do_constant_folding=True, # 是否执行常量折叠优化
  14. input_names=["input"], # 输入名
  15. output_names=["output"], # 输出名
  16. dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
  17. "output":{0:"batch_size"}})

链接:

1)pytorch官方文档.

2)参考

pytorch->onnx常见错误

  1. pth->onnx常见问题
  2. ##模型输入输出不支持字典
  3. 在使用torch.onnx导出onnx格式时,模型的输入和输出都不支持字典类型的数据结构。
  4. **解决方法:**
  5. 此时,可以将字典数据结构换为torch.onnx支持的列表或者元组。例如:
  6. heads {'hm': 1, 'wh': 2, 'hps': 34, 'reg': 2, 'hm_hp': 17, 'hp_offset': 2}
  7. Centerpose中的字典在导出onnx时将会报错,可以将字典拆为两个列表,然后模型代码中做相应修改。
  8. ##tensor.size(dim)操作导致的转出的onnx操作中带有Gather[0]
  9. 如果在网络中存在类似tensor.size(dim)或者tensor.shape[1]这种索引操作,就会在转出的onnx模型中产生出Gather[0]操作,而这个操作,在TensorRT中是不支持的。
  10. **解决方法:**
  11. 尽量避免类似的操作,代码中不要出现与网络操作无关的tensor,一些尺寸之类的常数全部写成普通变量。
  12. ##求余算符在导出onnx时不支持
  13. 曾经在导出求余算符‘%’时出现如下报错,是因为此运算符为ATen操作符,但是未在torch/onnx/symbolic_opset9.py中定义。
  14. /usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:562: UserWarning: ONNX export failed on ATen operator remainder because torch.onnx.symbolic_opset9.remainder does not exist
  15. .format(op_name, opset_version, op_name))
  16. Traceback (most recent call last):
  17. File "transfer_centernet_dlav0_to_onnx2.py", line 65, in <module>
  18. onnx_module = torch.onnx.export(model,img,'multipose_dlav0_1x_modellast_process.onnx',verbose=True)
  19. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 132, in export
  20. strip_doc_string, dynamic_axes)
  21. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 64, in export
  22. example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
  23. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 329, in _export
  24. _retain_param_name, do_constant_folding)
  25. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 225, in _model_to_graph
  26. _disable_torch_constant_prop=_disable_torch_constant_prop)
  27. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 127, in _optimize_graph
  28. graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  29. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 163, in _run_symbolic_function
  30. return utils._run_symbolic_function(*args, **kwargs)
  31. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 563, in _run_symbolic_function
  32. op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
  33. File "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
  34. return _registry[(domain, version)][opname]
  35. KeyError: 'remainder'
  36. **解决方法:**
  37. 在torch/onnx/symbolic_opset9.py中补上自己的remainder的定义,添加的代码如下
  38. @parse_args( 'v', 'v')
  39. def remainder(g,input,division):
  40. return g.op("Remainder",input,division)
  1. "', since it's not constant, please try to make "
  2. RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible
  3. 解决方案:可能是pytorch版本不对,当前是pytorch1.5版本,但是换了好多版本还是不行,这个问题有待解决

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/367238
推荐阅读
相关标签
  

闽ICP备14008679号