赞
踩
1. PyTorch导出ONNX
torch.onnx.export函数实现了PyTorch模型到onnx模型的导出,在PyTorch 1.11.0中,torch.onnx.export函数参数如下:
- def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
- input_names=None, output_names=None, aten=False, export_raw_ir=False,
- operator_export_type=None, opset_version=None, _retain_param_name=True,
- do_constant_folding=True, example_outputs=None, strip_doc_string=True,
- dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
- enable_onnx_checker=True, use_external_data_format=False):
参数比较多,但常用的有如下几个:
- model: pytorch模型
-
- args: 第一个参数model的输入数据,因为模型的输入可能不止一个,因此采用元组作为参数
-
- f: 导出的onnx模型文件路径
-
- export_params: 导出的onnx模型文件可以包含网络结构与权重参数,如果设置该参数为False,则导出的onnx模型文件只包含网络结构,因此,一般保持默认为True即可
-
- verbose: 该参数如果指定为True,则在导出onnx的过程中会打印详细的导出过程信息
-
- input_names: 为输入节点指定名称,因为输入节点可能多个,因此该参数是一个列表
-
- output_names: 为输出节点指定名称,因为输出节点可能多个,因此该参数是一个列表
-
- opset_version: 导出onnx时参考的onnx算子集版本
-
- dynamic_axes: 指定输入输出的张量,哪些维度是动态的,通过用字典的形式进行指定,如果某个张量的某个维度被指定为字符串或者-1,则认为该张量的该维度是动态的,但是一般建议只对batch维度指定动态,这样可提高性能,具体的格式见下面的代码
如下代码,定义了一个包含卷积层、relu激活层的网络,将该网络导出onnx模型,设置了输入、输出的batch、height、width 3 个维度是动态的。
import torch
import torch.nn as nn
import torch.onnx
import os


class Model(torch.nn.Module):
    """Toy conv + ReLU network with deterministic parameters.

    The 3x3 conv weights are filled with 1 and the bias with 0, so the
    output is just the (ReLU-clamped) sum of each input pixel's 3x3
    neighborhood — handy for eyeballing the exported ONNX graph.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv.weight.data.fill_(1)  # weights initialized to 1
        self.conv.bias.data.fill_(0)    # bias initialized to 0

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Guarded so importing this module does not trigger an export as a side
# effect (the original ran the export at import time).
if __name__ == "__main__":
    model = Model()
    dummy = torch.zeros(1, 1, 3, 3)

    torch.onnx.export(
        model,

        # Input fed to the model; may be several tensors, hence a tuple.
        (dummy,),

        # Path of the exported onnx file.
        "demo.onnx",

        # Print detailed information during export.
        verbose=True,

        # Names for the input and output nodes, convenient for later
        # inspection or manipulation of the graph.
        input_names=["image"],
        output_names=["output"],

        # ONNX operator-set version referenced during export.
        opset_version=11,

        # Mark batch, height and width as dynamic dimensions; they are
        # exported as -1 in the onnx graph. In practice it is usually
        # better to keep only batch dynamic.
        dynamic_axes={
            "image": {0: "batch", 2: "height", 3: "width"},
            "output": {0: "batch", 2: "height", 3: "width"},
        }
    )

    print("Done.!")
2. netron可视化
netron可视化可以看到网络输入层为image,输出层为output,这些层名都是在onnx导出时指定的,另外红色框标注处,显示batch、height、width三个维度为动态的。
3. 修改onnx模型
1)修改模型输入尺寸
(1):动态尺寸修改为静态尺寸
- import onnx
- import onnxruntime as rt
- import os
- import numpy as np
- import argparse
-
class fix_dim_tools:
    """Freeze dynamic input dimensions of an ONNX model to concrete shapes.

    Loads the model, detects graph inputs with dynamic (zero-valued) dims,
    runs one inference through onnxruntime with fake data to learn the
    output shapes, then rewrites the graph's input/output dims in place.
    """

    def __init__(self, model_path, inputs_shape, inputs_dtype):
        """
        model_path   -- path to the .onnx file (must exist)
        inputs_shape -- list of concrete shapes, one per graph input
        inputs_dtype -- list of 'float'/'int' per input, or None for all-float
        """
        assert os.path.exists(model_path), "{} not exists".format(model_path)
        if inputs_dtype is None:
            print('inputs_dtype is not define, use float for all inputs node')
            inputs_dtype = ['float'] * len(inputs_shape)
        else:
            assert len(inputs_shape) == len(inputs_dtype), "inputs shape list should have same length as inputs_dtype"

        self.model = onnx.load(model_path)
        self.model_path = model_path
        self.inputs_shape = inputs_shape
        self.inputs_dtype = inputs_dtype

        # name -> resolved concrete shape / dtype tag for every graph input
        self.inputs_shape_dict = {}
        self.inputs_type_dict = {}
        # name -> shape of every graph output, filled after inference
        self.outputs_shape_dict = {}

    def check_dynamic_input(self):
        """Resolve each graph input to a concrete shape.

        Returns True when at least one input had a dynamic dimension
        (dynamic dims are serialized with dim_value == 0).
        """
        graph_inputs = self.model.graph.input
        assert len(graph_inputs) == len(self.inputs_shape), "model has {} inputs, but {} inputs_shape was given, not match".format(len(graph_inputs), len(self.inputs_shape))

        found_dynamic = False
        for node, given_shape, given_dtype in zip(graph_inputs, self.inputs_shape, self.inputs_dtype):
            declared = [d.dim_value for d in node.type.tensor_type.shape.dim]
            if 0 in declared:
                # at least one dim is dynamic: substitute the user shape
                found_dynamic = True
                print('Input node:{} is dynamic input, the shape info is {}. Using given shape-{} instead.'.format(node.name, declared, given_shape))
                self.inputs_shape_dict[node.name] = given_shape
            else:
                print('Input node:{} is normal input, the shape info is {}. Ignore given shape-{}'.format(node.name, declared, given_shape))
                self.inputs_shape_dict[node.name] = declared
            self.inputs_type_dict[node.name] = given_dtype
        return found_dynamic

    def run_onnxruntime_to_get_output_shape(self):
        """Run one inference with random inputs to discover output shapes."""
        sess = rt.InferenceSession(self.model_path)

        feed = {}
        for name, shape in self.inputs_shape_dict.items():
            tag = self.inputs_type_dict[name]
            if tag == 'int':
                feed[name] = np.random.randint(0, 255, shape).astype(np.uint8)
            elif tag == 'float':
                feed[name] = np.random.randn(*shape).astype(np.float32)

        fetch_names = [out.name for out in sess.get_outputs()]
        results = sess.run(fetch_names, feed)
        for name, value in zip(fetch_names, results):
            self.outputs_shape_dict[name] = list(value.shape)
            print("After inference, output:{} shape is [{}]".format(name, list(value.shape)))

    def run(self):
        """Freeze all dynamic dims. Returns True on success, False otherwise."""
        if not self.check_dynamic_input():
            print("{} hasn't dynamic input, no file will be generated.".format(self.model_path))
            return False

        self.run_onnxruntime_to_get_output_shape()
        try:
            from onnx.tools import update_model_dims
            # updates self.model's input/output dims in place
            update_model_dims.update_inputs_outputs_dims(self.model, self.inputs_shape_dict, self.outputs_shape_dict)
            return True
        except Exception as e:
            print("Automatically fix failed. Please try to export an non-dynamic onnx model first. Exception: {}".format(repr(e)))
            return False

    def export(self, output_path):
        """Save the (possibly fixed) model proto to output_path."""
        onnx.save(self.model, output_path)
-
-
if __name__ == '__main__':
    # CLI entry point: parse shapes/dtypes, then freeze the model's dims.
    parser = argparse.ArgumentParser(description="This script can fix the dim for onnx model which have dynamic inputs.\n"
                                                 "Using this file such as: Python freeze_inshape_for_onnx_model.py "
                                                 "--onnx_file_path model_in.onnx "
                                                 "--output_path model_out.onnx "
                                                 "--inputs_shape 1,3,112,112#1,3,64,64 "
                                                 "--inputs_dtype float#int")
    parser.add_argument('--onnx_file_path', type=str, default='./models/pplcnet.onnx', help='input model path')
    parser.add_argument('--output_path', type=str, default='./models/pplcnet_freeze.onnx', help='output model path')
    parser.add_argument('--inputs_shape', type=str, default='1,3,224,224', help='inputs shape list, write as 1,3,112,112#1,3,64,64 for multi-input.')
    # Help text fixed: this option is the dtype list, not a shape list,
    # and "Defualt" was a typo.
    parser.add_argument('--inputs_dtype', type=str, required=False, help='(Optional) inputs dtype list, write as float#int for multi-input. Default as float')

    args = parser.parse_args()

    # "1,3,112,112#1,3,64,64" -> [[1, 3, 112, 112], [1, 3, 64, 64]]
    inputs_shape = [[int(dim) for dim in shape_str.split(',')]
                    for shape_str in args.inputs_shape.split('#')]

    inputs_dtype = None if args.inputs_dtype is None else args.inputs_dtype.split('#')

    fdt = fix_dim_tools(args.onnx_file_path, inputs_shape, inputs_dtype)
    if fdt.run() is True:
        fdt.export(args.output_path)
        print('Success! Fix dim done. Export new model to path:{}'.format(args.output_path))
(2)修改batch size
import onnx
import onnx_graphsurgeon as gs

if __name__ == '__main__':
    # Rewrite the batch dimension of a fixed-shape ONNX model.
    model_path = "./models/onnx/pplcnet_1b.onnx"
    output_model_path = "./models/onnx/pplcnet_4b.onnx"
    new_batch_size = 4  # the new batch size

    graph = gs.import_onnx(onnx.load(model_path))

    # Overwrite dim 0 (batch) on the first graph input and first output.
    for tensor in (graph.inputs[0], graph.outputs[0]):
        shape = tensor.shape
        shape[0] = new_batch_size
        tensor.shape = shape

    onnx.save(gs.export_onnx(graph), output_model_path)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。