当前位置:   article > 正文

PyTorch模型转ONNX格式_vgg16使用pytorch训练的pth模型怎么转成onnx模型

vgg16使用pytorch训练的pth模型怎么转成onnx模型

前言

在使用PyTorch进行网络训练得到.pth模型文件后,我们可能会做一些模型部署和加速的工作。这里一般会涉及到将PyTorch模型转为ONNX模型的过程。
在这里插入图片描述
PyTorch自带了ONNX转换方法(torch.onnx.export),可以很方便的将一些仅包含通用算子的网络的PyTorch模型转为ONNX格式。

转换步骤

准备模型文件

将PyTorch模型文件准备好,放在’./weights/torch.pth’路径下。

编写转换代码

__author__ = 'TracelessLe'

import torch

TORCH_WEIGHT_PATH = './weights/torch.pth'
ONNX_MODEL_PATH = 'net_bs8_v1.onnx'

def get_numpy_data():
	batch_size = 8
    img_input = np.ones((batch_size, 3, 128, 128), dtype=np.float32)
    return img_input

def get_torch_model():
    # Load Network Here
    pass 

def torch2onnx(img_input, onnx_model_path, device_id=0):
    torch_model = get_torch_model()  # Network define
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)
    torch_weights = torch.load(TORCH_WEIGHT_PATH)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()
    dummy_img = torch.Tensor(img_input).to(device)
    torch.onnx.export(
        torch_model,
        (dummy_img),
        onnx_model_path,
        input_names=['input'],
        output_names=['output'],
        export_params=True,
        verbose=True,
        do_constant_folding=False,  # or True
        opset_version=11
    )
    print("Generate ONNX file over!")

if __name__ == "__main__":
    img_input = get_numpy_data()
    torch2onnx(img_input, ONNX_MODEL_PATH)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

简化ONNX模型结构(可选)

生成的ONNX结构可能还有简化的空间,可以使用onnx-simplifier工具进一步优化。

__author__ = 'TracelessLe'

import onnx
from onnxsim import simplify

ONNX_MODEL_PATH = 'net_bs8_v1.onnx'
ONNX_SIM_MODEL_PATH = 'net_bs8_v1_simple.onnx'

if __name__ == "__main__":
    onnx_model = onnx.load(ONNX_MODEL_PATH)
    onnx_sim_model, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(onnx_sim_model, ONNX_SIM_MODEL_PATH)
    print('ONNX file simplified!')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

对比结果

__author__ = 'TracelessLe'

import time
import torch
import onnxruntime
import numpy as np

def test_torch(img_input, device_id=0, loop=100):
    torch_model = get_torch_model()
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)
    torch_weights = torch.load(TORCH_WEIGHT_PATH)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()
    dummy_img = torch.Tensor(img_input).to(device)
    batch_size = 8
    with torch.no_grad():
        out_img = torch_model(dummy_img)
    time1 = time.time()
    for i in range(loop):
        time_bs1 = time.time()
        with torch.no_grad():
            out_img = torch_model(dummy_img)
            out_img_numpy = out_img.detach().cpu().numpy()
        time_bs2 = time.time()
        time_use_pt_bs = time_bs2 - time_bs1
        print(f'PyTorch use time {time_use_pt_bs} for bs8')
    time2 = time.time()
    time_use_pt = time2-time1
    print(f'PyTorch use time {time_use_pt} for loop {loop}, FPS={loop*batch_size//time_use_pt}')
    return out_img_numpy

def test_onnx(inputs, loop=100):
    inputs = inputs.astype(np.float32)
    print(onnxruntime.get_device())
    sess = onnxruntime.InferenceSession(ONNX_SIM_MODEL_PATH)
    batch_size = 8
    time1 = time.time()
    for i in range(loop):
        time_bs1 = time.time()
        out_ort_img = sess.run(None, {sess.get_inputs()[0].name: inputs,})
        time_bs2 = time.time()
        time_use_onnx_bs = time_bs2 - time_bs1
        print(f'ONNX use time {time_use_onnx_bs} for bs8')
    time2 = time.time()
    time_use_onnx = time2-time1
    print(f'ONNX use time {time_use_onnx} for loop {loop}, FPS={loop*batch_size//time_use_onnx}')
    return out_ort_img

if __name__ == "__main__":
    img_input = get_numpy_data()
    out_ort_img = test_onnx(img_input, loop=1)[0]
    out_img_numpy = test_torch(img_input, loop=1)
    mse = np.square(np.subtract(out_ort_img, out_img_numpy)).mean()
    print('mse between pytorch and onnx result: ', mse)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

模型可视化(可选)

可以使用Netron可视化PyTorch、ONNX等模型结构。
在这里插入图片描述
通过Netron打开转换得到的ONNX以及简化后的ONNX文件,可以对比查看网络结构的变化。
在这里插入图片描述

其他说明

(1)do_constant_folding
在使用torch.onnx.export时可能会报错(与do_constant_folding相关),那么可以将do_constant_folding=True设为do_constant_folding=False

(2)自定义算子的转换
PyTorch转ONNX目前仅支持一些通用算子(见PyTorch Doc),自定义的算子在转出时会报错。可以通过改写一部分算子或者使用功能相近的通用算子替代。

(3)PyTorch转ONNX的方式
目前PyTorch转ONNX有两种方式,除了主流的torch.onnx.export外,还有torch.jit.trace。但是后一种还处于实验阶段,可能会遇到较多问题。

版权说明

本文为原创文章,独家发布在blog.csdn.net/TracelessLe。未经个人允许不得转载。如需帮助请email至tracelessle@163.com
在这里插入图片描述

参考资料

[1] ONNX | Home
[2] torch.onnx — PyTorch 1.9.0 documentation
[3] (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 1.9.0+cu102 documentation
[4] daquexian/onnx-simplifier: Simplify your onnx model
[5] lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/578536
推荐阅读
相关标签
  

闽ICP备14008679号