赞
踩
【GiantPandaCV导语】本文是ONNX2Pytorch思路分享以及onnx-simplifier新版简要介绍。ONNX2Pytorch工具已经测试了onnx model zoo中的大量分类模型并转换正确,欢迎使用,github地址:https://github.com/BBuf/onnx2nn。GiantPandaCV几个月前遭受恶意举报,今天终于解除封印了。感谢众多粉丝们的长期等待和支持,我们会在此继续分享学习经验。
ONNX作为微软的神经网络模型的开放格式被各个框架广泛应用,包括Pytroch,TensorFlow,OneFlow,Keras,Paddle等多种深度学习训练框架。因此,之前一直在思考一个问题,一个TensorFlow/MxNet/Keras导出来的ONNX模型是否可以借助ONNX被Pytorch框架使用呢?ONNX的理想是作为所有训练框架模型的中间表示,那么我们只需要再实现ONNX到各个框架的逆转就可以完成这件事情了。本工程的目的即是尝试支持ONNX转换到Pytorch,主要为了锻炼算子对齐和更深入的了解ONNX。先放一下github地址:https://github.com/BBuf/onnx2nn,欢迎关注。这个工程复用了https://github.com/ToriML/onnx2pytorch
的整体逻辑,解决了原始工程中遗留的大量BUG,支持了更多OP,实现了输入一个ONNX模型,返回一个torch.nn.Module
对象,并将这个torch.nn.Module
对应的Pytorch模型保存下来。
首先需要说明的是,在执行转换之前需要先过一遍onnx-simplifer对原始的ONNX模型进行简化,工程地址为:https://github.com/daquexian/onnx-simplifier 。为了使用方便,我将这个工具直接接入到了本工程,在后面的使用方法中可以看到。
然后这和项目的思路是非常简单的,直接遍历ONNX模型的计算节点(也即OP),把每个OP一对一的转换到Pytorch就可以了。核心代码地址为:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/operations.py#L20-L181。简单截图说明一下:
def convert_operations(onnx_model, batch_dim=0): """ Convert onnx model operations. Yields onnx's operator_id, opeartor_name and converted pytorch operator. Parameters ---------- onnx_model: onnx.ModelProto Loaded onnx model. batch_dim: int Usually 0 for computer vision models and 1 for NLP models. Returns ------- iterator: (op_id, op_name, op) """ weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer} for i, node in enumerate(onnx_model.graph.node): # extract only useful inputs params = [weights[par_name] for par_name in node.input if par_name in weights] if node.op_type == "Conv": op = convert_layer(node, "Conv", params) elif node.op_type == "Relu": op = nn.ReLU(inplace=True) elif node.op_type == "LeakyRelu": op = nn.LeakyReLU(**extract_attributes(node), inplace=True) elif node.op_type == "Sigmoid": op = nn.Sigmoid() elif node.op_type == "MaxPool": op = convert_layer(node, "MaxPool") elif node.op_type == "AveragePool": op = convert_layer(node, "AvgPool") elif node.op_type == "Flatten": op = Flatten(**extract_attributes(node)) elif node.op_type == "Gemm": op = convert_linear_layer(node, params) op.feature_dim = batch_dim + 1 # Necessary for transformers elif node.op_type == "BatchNormalization": op = convert_batch_norm_layer(node, params=params) elif node.op_type == "InstanceNormalization": op = convert_instance_norm_layer(node, params=params) elif node.op_type == "Concat": op = Concat(**extract_attributes(node)) else pass op_name = "{}_{}".format(node.op_type, node.output[0]) op_id = node.output[0] yield op_id, op_name, op
可以看到通过遍历ONNX模型的所有计算节点并获取每个节点的信息(输入参数以及各种attribute)之后将其用Pytorch的对应OP写出来就完成了转换过程。里面涉及到的每个OP的具体转换过程比如权重,attribute参数的提取以及对应Pytorch的实现等可以直接查看源码,这里不详细展开。
在获得每个ONNX计算节点对应的Pytorch OP之后,我们需要根据ONNX的计算节点反应的拓扑关系把所有的Pytorch OP组合成一个完整的Pytorch的模型,这部分的代码实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/model.py#L36-L131
。
在执行ONNX2Pytorch的过程中需要注意一些由于Pytorch和ONNX OP实现不一致而导致模型转换失败的情况,下面列举一下:
非对称Padding问题。在对alexnet和google-net进行转换时发现它们的卷积或者Max Pooling层经常会出现非对称Padding的情况,由于Pytorch的卷积和最大池化操作不支持不对称Padding操作,所以这个时候为了保证转换的等价,需要将这个非对称Padding的OP拆成nn.ConstantPad2d
+无Padding的原始OP。
count_include_pad问题。在对inception-net进行转换时发现到了最后一个Avg Pooling层时出现了精度严重下降,经过Debug发现,Pytorch的Avg Pooling层的count_include_pad默认为True。如果这个时候也是非对称的Padding,那么按照上面的处理方法拆分成ConstantPad2d+Avg Pooling
之后会丢失精度,因为这种情况下Avg Pooling无法知晓自己Padding了多少元素。如下图所示:
这个时候可以通过修改Kernel尺寸的方法来规避这个问题,在上面的例子中我们可以直接让kernel_shape等于(7-1=6,7-1=6)
并且省掉新增常量Pad的操作。
这两点的代码实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/layer.py#L30-L91
然后我们看一下Pytorch的LRN层的API:
对比一下API的参数可以发现ONNX里面的bias对应的是Pytorch LRN里面的参数k,所以这里需要特殊处理一下,获取这个attribute的bias参数的值之后将其设为Pytorch LRN层里面的k参数的值。具体实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/attribute.py#L132-L139
- onnx2pytorch onnx转pytorch代码实现
- onnx2pytorch.py onnx转pytorch测试代码
- convert_models.md 转换ONNX Model Zoo里面的模型对应的命令和结果记录
- README.md
使用下面的命令将各个训练框架导出的ONNX模型转换成Pytorch模型
python .\onnx2pytorch.py ...
参数列表如下:
--onnx_path
字符串,必选参数,代表onnx模型的路径--pytorch_path
字符串,必选参数,代表转换出的Pytorch模型保存路径--simplify_path
字符串,可选参数,代表ONNX模型简化(例如删除Dropout和常量OP)后保存的ONNX模型路径--input_shape
字符串,必选参数,代表ONNX模型的输入数据层的名字和维度信息python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-7.onnx --simplify_path .\models\mobilenetv2-7-simplify.onnx --pytorch_path .\models\mobilenetv2-7.pth --input_shape input:1,3,224,224
onnx2pytorch.py
里面的model = convert.ConvertModel(onnx_model, debug=False)
这行代码里面的debug
设置False重新运行模型即可定位到转换失败的OP,然后你可以在工程提出issue或者自己解决然后给本工程PR。基于ONNXRuntime和Pytorch推理之后特征值mse小于1e-7,视为转换成功
onnx-simplifer最近迎来了一次更新,这次更新是和onnxruntime一起更新的,小伙伴们要使用最新版本记得把onnxruntime更新到1.6.0哦。然后我去阅读了一下最新的onnx-simplifer,在上次的ONNX初探基础上,增加了一个递归函数fixed_point
,功能就是递归执行func_a和fun_b直到模型稳定,代码如下:
# 递归执行func_a和func_b直到模型稳定 def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T: """ Run `func_a` and `func_b` on `x` until func_b(func_a(x)) == x :param x: :param func_a: A function satisfying func_a(func_a(x)) == func_a(x) :param func_b: A function satisfying func_b(func_b(x)) == func_b(x) :return: the x that satisfies func_b(func_a(x)) == x """ x = func_a(x) x = func_b(x) while True: y = func_a(x) if y == x: # Since func_b(func_b(x)) == func_b(x), # we are already at the fixed point if # `y == x` return x x = y y = func_b(x) if y == x: return x x = y
我们看一下它是怎么应用的就可以了,注释如下:
def simplify(model: Union[str, onnx.ModelProto], check_n: int = 0, perform_optimization: bool = True, skip_fuse_bn: bool = False, input_shapes: Optional[TensorShapesWithOptionalKey] = None, skipped_optimizers: Optional[Sequence[str]] = None, skip_shape_inference=False, input_data: Optional[Tensors] = None, dynamic_input_shape: bool = False, custom_lib: Optional[str] = None) -> Tuple[onnx.ModelProto, bool]: """ :param model: onnx ModelProto object or file path :param check_n: The simplified model will be checked for `check_n` times by random inputs :param perform_optimization: Whether to run onnx optimizer on the model :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer :param input_shapes: If the model has dynamic input shape, user must pass a fixed input shape for generating random inputs and checking equality. (Also see "dynamic_input_shape" param) :param skipped_optimizers: Skip some specific onnx optimizers :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash) :param input_data: Feed custom input data for checking if needed :param dynamic_input_shape: Indicates whether the input shape should be dynamic. Note that input_shapes is also needed even if dynamic_input_shape is True, the value of input_shapes will be used when generating random inputs for checking equality. If 'dynamic_input_shape' is False, the input shape in simplified model will be overwritten by the value of 'input_shapes' param. :param custom_lib: onnxruntime custom ops's shared library :return: A tuple (simplified model, success(True) or failed(False)) """ if input_shapes is None: input_shapes = {} if input_data is None: input_data = {} if type(model) == str: # 加载ONNX模型 model = onnx.load(model) assert(isinstance(model, onnx.ModelProto)) # 检查ONNX模型格式是否正确,图结构是否完整,节点是否正确等 onnx.checker.check_model(model) # 深拷贝一份原始ONNX模型 model_ori = copy.deepcopy(model) input_names = get_input_names(model) for input_name, data in input_data.items(): if input_name not in input_names: raise RuntimeError( 'The model doesn\'t have input named "{}"'.format(input_name)) shape = list(input_data[input_name].shape) # special case for single constant variables (with shape []) if len(shape) == 0: shape = [input_data[input_name].size] if input_name in input_shapes and shape != input_shapes[input_name]: raise RuntimeError('The shape of input_data[{}] is not the same with input_shape[{}]'.format( input_name, input_name)) elif input_name not in input_shapes: input_shapes[input_name] = shape # 检查核对输入节点 updated_input_shapes = check_and_update_input_shapes(model, input_shapes) def infer_shapes_and_optimize(model: onnx.ModelProto) -> onnx.ModelProto: # 做ONNX模型节点形状推断 def infer_shapes_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto: if not skip_shape_inference: model = infer_shapes(model) return model # 对ONNX模型进行optimizer def optimize_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto: if perform_optimization: model = optimize(model, skip_fuse_bn, skipped_optimizers) return model # 递归执行infer_shapes_if_applicable和optimize_if_applicable直到模型稳定 return fixed_point(model, infer_shapes_if_applicable, optimize_if_applicable) def constant_folding(model: onnx.ModelProto) -> onnx.ModelProto: # 获取模型的常量OP const_nodes = get_constant_nodes( model, dynamic_input_shape=dynamic_input_shape) # 获取所有的常量OP以及原始输出OP的特征值 res = forward_for_node_outputs(model, const_nodes, input_shapes=updated_input_shapes, input_data=input_data, custom_lib=custom_lib) # 清洗那些没有被onnxruntime推理的静态节点 const_nodes = clean_constant_nodes(const_nodes, res) # 移除常量OP,获得简化后的ONNX模型 model = eliminate_const_nodes(model, const_nodes, res) # 检查ONNX模型格式是否正确,图结构是否完整,节点是否正确等 onnx.checker.check_model(model) return model # 递归执行infer_shapes_and_optimize和constant_folding直到模型稳定 model = fixed_point(model, infer_shapes_and_optimize, constant_folding) # 重写模型的输入shape if not dynamic_input_shape: for name, input_shape in updated_input_shapes.items(): for ipt in model.graph.input: if ipt.name == name: for i, dim in enumerate(ipt.type.tensor_type.shape.dim): dim.dim_value = input_shape[i] # 检查核对输入节点 check_ok = check(model_ori, model, check_n, input_shapes=updated_input_shapes) return model, check_ok
现在onnx-simplifer在简化过程中会递归的去推断shape,折叠常量,以及optimizer。所以这个程序比较依赖各个操作都不出错,如果某一步发生错误,可能有qia住的风险哦。使用最新版onnx-simplifer前切记更新onnxruntime到最新版本,否则使用model zoo里面的mobilenet模型就会引发qia住这一现象。
了解更多onnx-simplifer,比如执行流程,每一步再干什么请看ONNX初探的文章以及大老师发布的onnx simplifier 和 optimizer。
BBuf只是API搬运工,onnxoptimizer和onnx-simplifer的作者大老师才是yyds。
之前写过和整理一些ONNX学习笔记,现在汇总如下,如果你是从模型部署来看ONNX,其实我个人认为看这些了解就差不多了,当然有新的想法我也会继续更新的(鸽。
欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧
有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:
为了方便读者获取资料以及我们公众号的作者发布一些Github工程的更新,我们成立了一个QQ群,二维码如下,感兴趣可以加入。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。