本文是ONNX2Pytorch思路分享以及onnx-simplifier新版简要介绍。ONNX2Pytorch工具已经测试了onnx model zoo中的大量分类模型并转换正确,欢迎使用,github地址:https://github.com/BBuf/onnx2nn。
首先需要说明的是,在执行转换之前需要先用onnx-simplifier对原始的ONNX模型进行一遍简化,工程地址为:https://github.com/daquexian/onnx-simplifier 。为了使用方便,我将这个工具直接接入到了本工程,在后面的使用方法中可以看到。
def convert_operations(onnx_model, batch_dim=0):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_model: onnx.ModelProto
        Loaded onnx model.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.

    Yields
    ------
    (op_id, op_name, op): tuple
        ``op_id`` is the node's first output name, ``op_name`` is
        ``"<op_type>_<first output name>"`` and ``op`` is the converted
        PyTorch operator.

    Raises
    ------
    NotImplementedError
        If a node's ``op_type`` has no conversion rule below.
    """
    # Map initializer names to their tensors so each node can pick out the
    # constant inputs (weights, biases, ...) it needs.
    weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}

    for node in onnx_model.graph.node:
        # extract only useful inputs: those that are backed by an initializer
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "Concat":
            op = Concat(**extract_attributes(node))
        else:
            # Fail loudly: silently continuing would leave `op` unbound, or
            # re-yield the previous node's operator under this node's id and
            # build a wrong model.
            raise NotImplementedError(
                "Conversion not implemented for op_type={}.".format(node.op_type)
            )

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
在获得每个ONNX计算节点对应的Pytorch OP之后,我们需要根据ONNX计算节点所反映的拓扑关系,把所有的Pytorch OP组合成一个完整的Pytorch模型,这部分的代码实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/model.py#L36-L131
在执行ONNX2Pytorch的过程中需要注意一些由于Pytorch和ONNX OP实现不一致而导致模型转换失败的情况,下面列举一下:
非对称Padding问题。在对alexnet和google-net进行转换时发现,它们的卷积层或者Max Pooling层经常会出现非对称Padding的情况。由于Pytorch的卷积和最大池化操作不支持非对称Padding,为了保证转换前后等价,需要将这个带非对称Padding的OP拆成两部分:一个nn.ConstantPad2d,后面接一个不带Padding的卷积或最大池化OP。
count_include_pad问题。在对inception-net进行转换时发现,到最后一个Avg Pooling层时精度出现了严重下降。经过Debug发现,Pytorch的Avg Pooling层的count_include_pad默认为True,而ONNX的AveragePool中count_include_pad属性默认为0(即求平均时不把Padding出来的元素计算在内)。如果此时的Padding又是非对称的,那么按照上面的处理方法拆分成ConstantPad2d+Avg Pooling
之后会丢失精度。原因是Padding出来的元素已经成为Avg Pooling输入的一部分,Avg Pooling无法知晓自己Padding了多少元素,也就无法在求平均时把它们排除。如下图所示:
对比一下API的参数可以发现ONNX里面的bias对应的是Pytorch LRN里面的参数k,所以这里需要特殊处理一下,获取这个attribute的bias参数的值之后将其设为Pytorch LRN层里面的k参数的值。具体实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/attribute.py#L132-L139
- onnx2pytorch onnx转pytorch代码实现
- onnx2pytorch.py onnx转pytorch测试代码
- convert_models.md 转换ONNX Model Zoo里面的模型对应的命令和结果记录
python .\onnx2pytorch.py ...
字符串,必选参数,代表ONNX模型的输入数据层的名字和维度信息,格式为“名字:维度1,维度2,...”。使用示例:
python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-7.onnx --simplify_path .\models\mobilenetv2-7-simplify.onnx --pytorch_path .\models\mobilenetv2-7.pth --input_shape input:1,3,224,224
里面的model = convert.ConvertModel(onnx_model, debug=False)
# Alternate func_a and func_b until the model stops changing.
def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T:
    """
    Alternately apply `func_a` and `func_b` to `x` until one of them leaves
    the value unchanged, i.e. until func_b(func_a(x)) == x.

    Both functions are expected to be idempotent. Once either one returns a
    value equal to its input, the pair has reached a fixed point.

    :param x: the starting value
    :param func_a: A function satisfying func_a(func_a(x)) == func_a(x)
    :param func_b: A function satisfying func_b(func_b(x)) == func_b(x)
    :return: the x that satisfies func_b(func_a(x)) == x
    """
    # One unconditional round first; afterwards every application is
    # compared against its input.
    current = func_b(func_a(x))
    while True:
        for step in (func_a, func_b):
            candidate = step(current)
            if candidate == current:
                # Idempotence of the other function means nothing more can
                # change: we are at the fixed point.
                return current
            current = candidate
def simplify(model: Union[str, onnx.ModelProto],
             check_n: int = 0,
             perform_optimization: bool = True,
             skip_fuse_bn: bool = False,
             input_shapes: Optional[TensorShapesWithOptionalKey] = None,
             skipped_optimizers: Optional[Sequence[str]] = None,
             skip_shape_inference=False,
             input_data: Optional[Tensors] = None,
             dynamic_input_shape: bool = False,
             custom_lib: Optional[str] = None) -> Tuple[onnx.ModelProto, bool]:
    """
    Simplify an ONNX model. Shape inference plus onnx optimizer passes are
    alternated with constant folding until the model stops changing, and the
    result is then checked against the original.

    :param model: onnx ModelProto object or file path
    :param check_n: The simplified model will be checked for `check_n` times by random inputs
    :param perform_optimization: Whether to run onnx optimizer on the model
    :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer
    :param input_shapes: If the model has dynamic input shape, user must pass a fixed input shape
            for generating random inputs and checking equality. (Also see "dynamic_input_shape" param)
            The mapping passed in is not modified.
    :param skipped_optimizers: Skip some specific onnx optimizers
    :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash)
    :param input_data: Feed custom input data for checking if needed
    :param dynamic_input_shape: Indicates whether the input shape should be dynamic. Note that
            input_shapes is also needed even if dynamic_input_shape is True,
            the value of input_shapes will be used when generating random inputs for checking equality.
            If 'dynamic_input_shape' is False, the input shape in simplified model will be overwritten
            by the value of 'input_shapes' param.
    :param custom_lib: onnxruntime custom ops's shared library
    :return: A tuple (simplified model, success(True) or failed(False))
    :raises RuntimeError: if `input_data` names an input the model does not have, or if
            the shape of a tensor in `input_data` disagrees with `input_shapes`.
    """
    # Work on a private copy. Shapes derived from `input_data` are added
    # below, and they must not leak back into the caller's dict.
    input_shapes = {} if input_shapes is None else dict(input_shapes)
    if input_data is None:
        input_data = {}

    if isinstance(model, str):
        # Load the ONNX model from the given file path.
        model = onnx.load(model)
    assert isinstance(model, onnx.ModelProto)
    # Validate the model: format, graph structure, node definitions.
    onnx.checker.check_model(model)
    # Keep an untouched deep copy of the original model for the final check.
    model_ori = copy.deepcopy(model)

    input_names = get_input_names(model)
    for input_name, data in input_data.items():
        if input_name not in input_names:
            raise RuntimeError(
                'The model doesn\'t have input named "{}"'.format(input_name))

        shape = list(data.shape)
        # special case for single constant variables (with shape [])
        if len(shape) == 0:
            shape = [data.size]
        if input_name in input_shapes and shape != input_shapes[input_name]:
            raise RuntimeError('The shape of input_data[{}] is not the same with input_shape[{}]'.format(
                input_name, input_name))
        elif input_name not in input_shapes:
            input_shapes[input_name] = shape

    # Reconcile the requested input shapes with the model's inputs. The
    # returned mapping is what the rest of this function uses.
    updated_input_shapes = check_and_update_input_shapes(model, input_shapes)

    def infer_shapes_and_optimize(model: onnx.ModelProto) -> onnx.ModelProto:
        # Run ONNX shape inference on the model unless it was disabled.
        def infer_shapes_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto:
            if not skip_shape_inference:
                model = infer_shapes(model)
            return model

        # Run the onnx optimizer passes unless optimization was disabled.
        def optimize_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto:
            if perform_optimization:
                model = optimize(model, skip_fuse_bn, skipped_optimizers)
            return model

        # Alternate shape inference and optimization until the model is stable.
        return fixed_point(model, infer_shapes_if_applicable, optimize_if_applicable)

    def constant_folding(model: onnx.ModelProto) -> onnx.ModelProto:
        # Collect the nodes whose outputs are constant.
        const_nodes = get_constant_nodes(
            model, dynamic_input_shape=dynamic_input_shape)
        # Compute the output values of the constant nodes (and the original
        # graph outputs) by running the model.
        res = forward_for_node_outputs(model,
                                       const_nodes,
                                       input_shapes=updated_input_shapes,
                                       input_data=input_data,
                                       custom_lib=custom_lib)
        # Drop constant nodes whose outputs were not produced by the forward run.
        const_nodes = clean_constant_nodes(const_nodes, res)
        # Replace the constant nodes by their computed values.
        model = eliminate_const_nodes(model, const_nodes, res)
        # Re-validate the rewritten model.
        onnx.checker.check_model(model)
        return model

    # Alternate (shape inference + optimization) and constant folding until
    # the model is stable.
    model = fixed_point(model, infer_shapes_and_optimize, constant_folding)

    # Overwrite the input shapes of the simplified model with the fixed ones.
    if not dynamic_input_shape:
        for name, input_shape in updated_input_shapes.items():
            for ipt in model.graph.input:
                if ipt.name == name:
                    for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
                        dim.dim_value = input_shape[i]

    # Compare the simplified model against the original one.
    check_ok = check(model_ori, model, check_n,
                     input_shapes=updated_input_shapes)

    return model, check_ok
现在onnx-simplifier在简化过程中会反复地进行shape推断、常量折叠以及optimizer优化,直到模型不再变化为止。所以这个程序依赖于每一步操作都不出错;如果某一步发生错误,程序可能会卡住(hang)。使用最新版onnx-simplifier前切记把onnxruntime更新到最新版本,否则简化model zoo里面的mobilenet模型时就会出现卡住的现象。
了解更多onnx-simplifier的内容,比如执行流程、每一步在干什么,请看ONNX初探的文章以及大老师发布的onnx simplifier 和 optimizer。

