当前位置:   article > 正文

pytorch模型(.pt、.pth)转onnx模型(.onnx)的方法详解_pytorch模型 .pth.tar转onnx

pytorch模型 .pth.tar转onnx

.pt和.pth只能在pytorch的框架中使用,但是有时我们需要在其他的框架使用模型或者使用模型可视化工具来展示模型(大部分对.pt格式不兼容),这时就需要用到.onnx模型形式来转换了。

1、首先你要安装 依赖库:onnx 和 onnxruntime

  1. pip install onnx
  2. pip install onnxruntime 进行安装

2、pytorch模型转换到onnx模型

pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export 

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

参数说明:

  • model——需要导出的pytorch模型
  • args——模型的输入参数,满足输入层的shape正确即可。
  • path——输出的onnx模型的位置。例如‘yolov5.onnx’。
  • export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
  • verbose——是否打印模型转换信息。default=False。
  • input_names——输入节点名称。default=None。
  • output_names——输出节点名称。default=None。
  • do_constant_folding——是否使用常量折叠,默认即可。default=True。
  • dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。

转换代码:

  1. import io
  2. import torch
  3. import torch.onnx
  4. from cvit import CViT# 导入你的模型
  5. import os
  6. print(os.getcwd())
  7. def test():
  8. model = CViT()#将你自己的模型导入进来
  9. pthfile = r'../weight/deepfake_cvit_gpu_inference_ep_50.pth'#模型预训练参数文件位置
  10. loaded_model = torch.load(pthfile, map_location='cpu')
  11. model.load_state_dict(loaded_model)
  12. # data type nchw
  13. input = torch.randn(1, 3, 224, 224)# 这里根据你自己模型的输入维度进行改变,若不会看维度可以搜搜教程很简单。
  14. input_names = ["input"]
  15. output_names = ["output"]
  16. torch.onnx.export(model, input, "cvit_model1.onnx", verbose=False, opset_version=12, input_names=input_names, output_names=output_names)
  17. if __name__ == "__main__":
  18. test()

我在上述代码中注释的部分是需要你自己修改的,按照注释修改即可

3、错误

RuntimeError: Exporting the operator silu to ONNX opset version 12 is not supported. Please open a bug to request ONNX export support for the missing operator.

如果在运行过程中遇到这类错误只需要按照他给的建议,将上面标绿的version 12 修改为对应的版本即可。(代码的最后一段)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/850180
推荐阅读
相关标签
  

闽ICP备14008679号