赞
踩
.pt和.pth只能在pytorch的框架中使用,但是有时我们需要在其他的框架使用模型或者使用模型可视化工具来展示模型(大部分对.pt格式不兼容),这时就需要用到.onnx模型形式来转换了。
- pip install onnx
- pip install onnxruntime 进行安装
pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:
转换代码:
- import io
- import torch
- import torch.onnx
- from cvit import CViT# 导入你的模型
- import os
-
- print(os.getcwd())
-
-
- def test():
- model = CViT()#将你自己的模型导入进来
-
- pthfile = r'../weight/deepfake_cvit_gpu_inference_ep_50.pth'#模型预训练参数文件位置
- loaded_model = torch.load(pthfile, map_location='cpu')
-
- model.load_state_dict(loaded_model)
-
- # data type nchw
- input = torch.randn(1, 3, 224, 224)# 这里根据你自己模型的输入维度进行改变,若不会看维度可以搜搜教程很简单。
- input_names = ["input"]
- output_names = ["output"]
- torch.onnx.export(model, input, "cvit_model1.onnx", verbose=False, opset_version=12, input_names=input_names, output_names=output_names)
-
-
- if __name__ == "__main__":
- test()
我在上述代码中注释的部分是需要你自己修改的,按照注释修改即可
RuntimeError: Exporting the operator silu to ONNX opset version 12 is not supported. Please open a bug to request ONNX export support for the missing operator.
如果在运行过程中遇到这类错误只需要按照他给的建议,将上面标绿的version 12 修改为对应的版本即可。(代码的最后一段)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。