当前位置:   article > 正文

ONNX基本操作_onnx export

onnx export

1. Pytorch导出ONNX

torch.onnx.export函数实现了pytorch模型到onnx模型的导出,在pytorch1.11.0中,torch.onnx.export函数参数如下:

  1. def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
  2. input_names=None, output_names=None, aten=False, export_raw_ir=False,
  3. operator_export_type=None, opset_version=None, _retain_param_name=True,
  4. do_constant_folding=True, example_outputs=None, strip_doc_string=True,
  5. dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
  6. enable_onnx_checker=True, use_external_data_format=False):

参数比较多,但常用的有如下几个:

  1. model: pytorch模型
  2. args: 第一个参数model的输入数据,因为模型的输入可能不止一个,因此采用元组作为参数
  3. f: 导出的onnx模型文件路径
  4. export_params: 导出的onnx模型文件可以包含网络结构与权重参数,如果设置该参数为False,则导出的onnx模型文件只包含网络结构,因此,一般保持默认为True即可
  5. verbose: 该参数如果指定为True,则在导出onnx的过程中会打印详细的导出过程信息
  6. input_names: 为输入节点指定名称,因为输入节点可能多个,因此该参数是一个列表
  7. output_names: 为输出节点指定名称,因为输出节点可能多个,因此该参数是一个列表
  8. opset_version: 导出onnx时参考的onnx算子集版本
  9. dynamic_axes: 指定输入输出的张量,哪些维度是动态的,通过用字典的形式进行指定,如果某个张量的某个维度被指定为字符串或者-1,则认为该张量的该维度是动态的,但是一般建议只对batch维度指定动态,这样可提高性能,具体的格式见下面的代码

如下代码定义了一个包含卷积层、ReLU 激活层的网络,将该网络导出为 ONNX 模型,并设置了输入、输出的 batch、height、width 三个维度是动态的

  1. import torch
  2. import torch.nn as nn
  3. import torch.onnx
  4. import os
  5. # 定义一个模型
  6. class Model(torch.nn.Module):
  7. def __init__(self):
  8. super().__init__()
  9. self.conv = nn.Conv2d(1, 1, 3, padding=1)
  10. self.relu = nn.ReLU()
  11. self.conv.weight.data.fill_(1) # 权重被初始化为1
  12. self.conv.bias.data.fill_(0) # 偏置被初始化为0
  13. def forward(self, x):
  14. x = self.conv(x)
  15. x = self.relu(x)
  16. return x
  17. model = Model()
  18. dummy = torch.zeros(1, 1, 3, 3)
  19. torch.onnx.export(
  20. model,
  21. # 输入给model的数据,因为是元组类型,因此用括号
  22. (dummy,),
  23. # 导出的onnx文件路径
  24. "demo.onnx",
  25. # 打印导出过程详细信息
  26. verbose=True,
  27. # 为输入和输出节点指定名称,方便后面查看或者操作
  28. input_names=["image"],
  29. output_names=["output"],
  30. # 导出时参考的onnx算子集版本
  31. opset_version=11,
  32. # 设置batch、height、width3个维度是动态的,
  33. # 在onnx中会将其维度赋值为-1,
  34. # 通常,我们只设置batch为动态,其它的避免动态
  35. dynamic_axes={
  36. "image": {0: "batch", 2: "height", 3: "width"},
  37. "output": {0: "batch", 2: "height", 3: "width"},
  38. }
  39. )
  40. print("Done.!")

 2. netron可视化

 通过 netron 可视化可以看到,网络输入层为 image,输出层为 output,这些层名都是在 ONNX 导出时指定的;另外红色框标注处显示 batch、height、width 三个维度为动态的。

3. 修改onnx模型

1)修改模型输入尺寸

(1):动态尺寸修改为静态尺寸

  1. import onnx
  2. import onnxruntime as rt
  3. import os
  4. import numpy as np
  5. import argparse
  6. class fix_dim_tools:
  7. def __init__(self, model_path, inputs_shape, inputs_dtype):
  8. assert os.path.exists(model_path), "{} not exists".format(model_path)
  9. if inputs_dtype is None:
  10. print('inputs_dtype is not define, use float for all inputs node')
  11. inputs_dtype = ['float']*len(inputs_shape)
  12. else:
  13. assert len(inputs_shape)==len(inputs_dtype), "inputs shape list should have same length as inputs_dtype"
  14. model = onnx.load(model_path)
  15. self.model = model
  16. self.model_path = model_path
  17. self.inputs_shape = inputs_shape
  18. self.inputs_dtype = inputs_dtype
  19. self.inputs_shape_dict = {}
  20. self.inputs_type_dict = {}
  21. self.outputs_shape_dict = {}
  22. def check_dynamic_input(self):
  23. # check dynamic input and get real input shape
  24. inputs_number = len(self.model.graph.input)
  25. assert inputs_number==len(self.inputs_shape),"model has {} inputs, but {} inputs_shape was given, not match".format(inputs_number,len(self.inputs_shape))
  26. state = False
  27. for i in range(inputs_number):
  28. _input = self.model.graph.input[i]
  29. dim_values = [dim.dim_value for dim in _input.type.tensor_type.shape.dim]
  30. if 0 in dim_values:
  31. state = True
  32. print('Input node:{} is dynamic input, the shape info is {}. Using given shape-{} instead.'.format(_input.name, dim_values, self.inputs_shape[i]))
  33. self.inputs_shape_dict[_input.name] = self.inputs_shape[i]
  34. else:
  35. print('Input node:{} is normal input, the shape info is {}. Ignore given shape-{}'.format(_input.name, dim_values, self.inputs_shape[i]))
  36. self.inputs_shape_dict[_input.name] = dim_values
  37. self.inputs_type_dict[_input.name] = self.inputs_dtype[i]
  38. return state
  39. def run_onnxruntime_to_get_output_shape(self):
  40. sess = rt.InferenceSession(self.model_path)
  41. inputs_dict = {}
  42. # generate fake input
  43. for key in self.inputs_shape_dict.keys():
  44. if self.inputs_type_dict[key] == 'int':
  45. inputs_dict[key] = np.random.randint(0,255,self.inputs_shape_dict[key]).astype(np.uint8)
  46. elif self.inputs_type_dict[key] == 'float':
  47. inputs_dict[key] = np.random.randn(*self.inputs_shape_dict[key]).astype(np.float32)
  48. outputs_name = [node.name for node in sess.get_outputs()]
  49. res = sess.run(outputs_name, inputs_dict)
  50. for i in range(len(sess.get_outputs())):
  51. node_name = sess.get_outputs()[i].name
  52. self.outputs_shape_dict[node_name] = list(res[i].shape)
  53. print("After inference, output:{} shape is [{}]".format(node_name, list(res[i].shape)))
  54. def run(self):
  55. dynamic_input = self.check_dynamic_input()
  56. if dynamic_input is False:
  57. print("{} hasn't dynamic input, no file will be generated.".format(self.model_path))
  58. return False
  59. else:
  60. self.run_onnxruntime_to_get_output_shape()
  61. try:
  62. from onnx.tools import update_model_dims
  63. update_model_dims.update_inputs_outputs_dims(self.model, self.inputs_shape_dict, self.outputs_shape_dict)
  64. return True
  65. except Exception as e:
  66. print("Automatically fix failed. Please try to export an non-dynamic onnx model first. Exception: {}".format(repr(e)))
  67. return False
  68. def export(self,output_path):
  69. onnx.save(self.model, output_path)
  70. if __name__ == '__main__':
  71. parser = argparse.ArgumentParser(description="This script can fix the dim for onnx model which have dynamic inputs.\n"
  72. "Using this file such as: Python freeze_inshape_for_onnx_model.py "
  73. "--onnx_file_path model_in.onnx "
  74. "--output_path model_out.onnx "
  75. "--inputs_shape 1,3,112,112#1,3,64,64 "
  76. "--inputs_dtype float#int")
  77. parser.add_argument('--onnx_file_path', type=str, default='./models/pplcnet.onnx', help='input model path')
  78. parser.add_argument('--output_path', type=str, default='./models/pplcnet_freeze.onnx', help='output model path')
  79. parser.add_argument('--inputs_shape', type=str, default='1,3,224,224', help='inputs shape list, write as 1,3,112,112#1,3,64,64 for multi-input.')
  80. parser.add_argument('--inputs_dtype', type=str, required=False, help='(Options) inputs shape list, write as float#int for multi-input. Defualt as float')
  81. args = parser.parse_args()
  82. inputs_shape = []
  83. inputs_shape_define = args.inputs_shape.split('#')
  84. for shape_str in inputs_shape_define:
  85. inputs_shape.append([int(dim) for dim in shape_str.split(',')])
  86. if args.inputs_dtype is None:
  87. inputs_dtype = None
  88. else:
  89. inputs_dtype = args.inputs_dtype.split('#')
  90. fdt = fix_dim_tools(args.onnx_file_path, inputs_shape, inputs_dtype)
  91. state = fdt.run()
  92. if state is True:
  93. fdt.export(args.output_path)
  94. print('Success! Fix dim done. Export new model to path:{}'.format(args.output_path))

(2)修改batch size

  1. import onnx
  2. import onnx_graphsurgeon as gs
  3. if __name__ == '__main__':
  4. model_path = "./models/onnx/pplcnet_1b.onnx"
  5. output_model_path = "./models/onnx/pplcnet_4b.onnx"
  6. new_batch_size = 4 # 新的批处理大小
  7. model = onnx.load(model_path)
  8. graph = gs.import_onnx(model)
  9. input_node = graph.inputs[0]
  10. input_shape = input_node.shape
  11. input_shape[0] = new_batch_size
  12. input_node.shape = input_shape
  13. output_node = graph.outputs[0]
  14. output_shape = output_node.shape
  15. output_shape[0] = new_batch_size
  16. output_node.shape = output_shape
  17. onnx.save(gs.export_onnx(graph), output_model_path)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/367246
推荐阅读
相关标签
  

闽ICP备14008679号