当前位置:   article > 正文

如何在 ONNX 模型中增加或删除 node，即修改计算图的方法（graph.node.remove）

graph.node.remove

有时我们通过 PyTorch 导出 ONNX 模型后，需要修改 ONNX 的图结构，该怎么修改呢？

下面两个Python实例提供了修改思路。
Changing the graph is easier than recreating it with make_graph: just use append, remove and insert. 参考：https://github.com/onnx/onnx/issues/2259

import numpy as np
import onnx

# `onnxfile` (path of the model to patch) must be defined before this runs;
# the patched model is written back to the same path at the end.
onnx_model = onnx.load(onnxfile)
graph = onnx_model.graph

# Fix Pad: node index 584 and the tensor names '675'/'676' are specific to the
# author's model. The node is replaced by a 'Pad' whose pads are read from the
# 'avg_pads' initializer added here (eight zeros, i.e. no padding).
# np.int64 matches the declared INT64 element type explicitly instead of
# relying on the platform's default integer width.
pads = onnx.helper.make_tensor('avg_pads', onnx.TensorProto.INT64, [8], np.zeros(8, dtype=np.int64))
graph.initializer.append(pads)
node = graph.node[584]
new_node = onnx.helper.make_node(
    'Pad',
    name='__Pad_584_fixed',
    inputs=['675', 'avg_pads'],
    outputs=['676'],
    mode='constant'
)
# remove() + insert() at the same index keeps the node list's length and the
# node's position unchanged, so the other hard-coded index below still points
# at the same node.
graph.node.remove(node)
graph.node.insert(584, new_node)

# Fix Equals (replace with Not): node 322 becomes a 'Not' that reads tensor
# '412' and writes the original output tensor '414'.
node = graph.node[322]
new_node = onnx.helper.make_node(
    'Not',
    name='__Not__Equal_322',
    inputs=['412'],
    outputs=['414'],
)
graph.node.remove(node)
graph.node.insert(322, new_node)

# Validate the patched graph before overwriting the original file.
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, onnxfile)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

来源:https://github.com/saurabh-shandilya/onnx-utils


# ------------------------------------------------
# ONNX Model Editor and Graph Extractor
# License under The MIT License
# Written by Saurabh Shandilya
# -----------------------------------------------

import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse

def createGraphMemberMap(graph_member_list):
    """Index graph members (nodes, inputs, outputs, initializers) by ``name``.

    Args:
        graph_member_list: iterable of objects exposing a ``name`` attribute.

    Returns:
        dict mapping each name to its member; when several members share a
        name, the one appearing last wins.
    """
    return {member.name: member for member in graph_member_list}


def split_io_list(io_list,new_names_all):
    """Split a graph input/output list against the wanted set of names.

    Args:
        io_list: existing graph inputs or outputs (objects with ``name``).
        new_names_all: names the edited graph should expose.

    Returns:
        ``[removed_names, retained_names, new_names]``:
        removed_names  - existing names not wanted any more (io_list order),
        retained_names - existing names that are still wanted (io_list order),
        new_names      - wanted names not present yet, de-duplicated and kept
                         in the order given by ``new_names_all``.
    """
    # dict.fromkeys de-duplicates while preserving the caller's order. A set
    # difference would yield hash order, which changes between runs and made
    # the order of newly added graph inputs/outputs nondeterministic.
    ordered_wanted = list(dict.fromkeys(new_names_all))
    wanted = set(ordered_wanted)
    existing = set()
    removed_names = []
    retained_names = []
    for member in io_list:
        existing.add(member.name)
        if member.name in wanted:
            retained_names.append(member.name)
        else:
            removed_names.append(member.name)
    new_names = [name for name in ordered_wanted if name not in existing]
    return [removed_names, retained_names, new_names]
          
def traceDependentNodes(graph,name,node_input_names,node_map, initializer_map):
    """Collect the names of all nodes (and initializers) that tensor ``name`` depends on.

    Walks the graph backwards from ``name``. Every node producing a needed
    tensor is recorded and its inputs are traced in turn. Initializers reached
    on the way are recorded too, because they can be terminal inputs on a path.

    Args:
        graph: graph whose ``node`` list is searched (only ``graph.node`` is read).
        name: tensor name to start tracing from (e.g. a graph output).
        node_input_names: names already collected. The list is extended in
            place and returned; nodes whose name is already present are not
            traced again.
        node_map: node name -> node. A node's inputs are only followed when
            its name is a key of this map.
        initializer_map: initializer name -> initializer.

    Returns:
        ``node_input_names`` (the same list object), extended with the newly
        found node and initializer names.
    """
    # Build the producer index once (tensor name -> nodes emitting it) instead
    # of rescanning every node for every traced tensor, which is O(N^2).
    producers = {}
    for n in graph.node:
        for noutput in n.output:
            producers.setdefault(noutput, []).append(n)

    seen = set(node_input_names)
    visited_tensors = set()
    # Explicit stack instead of recursion: a dependency chain longer than
    # Python's recursion limit (about 1000 frames by default) would otherwise
    # raise RecursionError.
    pending = [name]
    while pending:
        tensor = pending.pop()
        if tensor in visited_tensors:
            continue
        visited_tensors.add(tensor)
        for n in producers.get(tensor, ()):
            if n.name in seen:
                continue
            seen.add(n.name)
            node_input_names.append(n.name)
            if n.name in node_map:
                # Follow this node's own inputs. node_map[n.name] could be a
                # different node when names are duplicated.
                pending.extend(reversed(n.input))
        # don't forget the initializers they can be terminal inputs on a path.
        if tensor in initializer_map and tensor not in seen:
            seen.add(tensor)
            node_input_names.append(tensor)
    return node_input_names
    
def onnx_edit(input_model, output_model, new_input_node_names, input_shape_map, new_output_node_names, output_shape_map, verify):
    """ edits and modifies an onnx model to extract a subgraph based on input/output node names and shapes.
    Arguments:
        input_model: path of input onnx model
        output_model: path of output onnx model
        new_input_node_names: list of input node names including list of original input nodes if they are to be retained.
            If the list is empty original input nodes are assumed.
        input_shape_map: dictionary/map of input node names to corresponding shapes. Shapes are needed for model checker to pass.
            None (what parse_nodename_and_shape returns when no shapes are given) is treated as an empty map.
        new_output_node_names: list of output node names, including list of original output nodes if they are to be retained
            If the list if empty original output nodes are assumed.
        output_shape_map: dictionary/map of output node names to corresponding shape. Shapes are needed for model checker to pass.
            None is treated as an empty map.
        verify: set to true if input and output models need to be verified.

    Newly created graph inputs/outputs are always declared with element type
    FLOAT. The edited model is saved to ``output_model``.
    """
    # parse_nodename_and_shape() returns None (not {}) for the shape map when
    # no "[...]" shapes were given; without this, the shape lookups below fail
    # with AttributeError on None.
    input_shape_map = input_shape_map or {}
    output_shape_map = output_shape_map or {}

    # LOAD MODEL AND PREP MAPS
    model = onnx.load(input_model)
    graph = model.graph
    if verify:
        # check_model() raises on an invalid model and returns None otherwise,
        # so a passing check prints "None" here.
        print("input model Errors: ", onnx.checker.check_model(model))

    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)
    output_map = createGraphMemberMap(graph.output)
    initializer_map = createGraphMemberMap(graph.initializer)

    if not new_input_node_names:
        new_input_node_names = list(input_map)
    if not new_output_node_names:
        new_output_node_names = list(output_map)

    # MODIFY INPUTS
    # Break the graph based on the new input node names
    removed_names, retained_names, new_names = split_io_list(graph.input, new_input_node_names)
    for name in removed_names:
        if name in input_map:
            graph.input.remove(input_map[name])
    for name in new_names:
        # If a new input name corresponds to an existing node, it implies that
        # original node in the graph needs to be replaced with an input node.
        # Exactly here the graph is broken.
        # NOTE(review): only a node whose *name* equals the new input name is
        # removed; a differently named node that merely outputs this tensor is
        # kept. Confirm this matches how the model's nodes are named.
        if name in node_map:
            graph.node.remove(node_map[name])
        new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, input_shape_map.get(name))
        graph.input.extend([new_nv])
    # Rebuild the maps: nodes and inputs changed above.
    node_map = createGraphMemberMap(graph.node)
    input_map = createGraphMemberMap(graph.input)

    # MODIFY OUTPUTS
    # Break the graph based on the new output node names
    removed_names, retained_names, new_names = split_io_list(graph.output, new_output_node_names)
    for name in removed_names:
        if name in output_map:
            graph.output.remove(output_map[name])
    for name in new_names:
        new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, output_shape_map.get(name))
        graph.output.extend([new_nv])
    output_map = createGraphMemberMap(graph.output)

    # CLEANUP NODES
    # Trace all dependent nodes for the current set of output nodes defined &
    # prepare a list of invalid nodes
    valid_node_names = []
    for new_output_node_name in new_output_node_names:
        valid_node_names = traceDependentNodes(graph, new_output_node_name, valid_node_names, node_map, initializer_map)
        valid_node_names = list(set(valid_node_names))
    # Everything (node or initializer) not reachable from the outputs is dead.
    invalid_node_names = list((set(node_map) | set(initializer_map)) - set(valid_node_names))
    # Remove all the invalid nodes from the graph
    for name in invalid_node_names:
        if name in node_map:
            graph.node.remove(node_map[name])
        if name in initializer_map:
            graph.initializer.remove(initializer_map[name])
        # Drops a graph input that shares its name with a removed initializer.
        if name in input_map:
            graph.input.remove(input_map[name])

    # SAVE MODEL
    if verify:
        print("output model Errors: ", onnx.checker.check_model(model))
    onnx.save(model, output_model)

def parse_nodename_and_shape(name):
    """Parse a comma-separated list of node names with optional shapes.

    Each entry is a name (usually ``name:0``, the suffix being the output
    index), optionally followed by a bracketed integer shape, e.g.
    ``"a:0[1,28,28,3],b"``.

    Returns:
        (names, shapes): ``names`` lists every parsed name in order. ``shapes``
        maps each name that carried a bracketed shape to its list of ints, or
        is None when no entry had a shape.
    """
    entry_pattern = re.compile(r"(?:([\w\d/\-\._:]+)(\[[\-\d,]+\])?),?")
    names = []
    shape_by_name = {}
    for match in entry_pattern.finditer(name):
        node_name, dims = match.group(1), match.group(2)
        names.append(node_name)
        if dims is not None:
            # Strip the surrounding brackets, then convert each dimension.
            shape_by_name[node_name] = [int(d) for d in dims[1:-1].split(",")]
    return names, (shape_by_name or None)
    
    
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="input onnx model")
    parser.add_argument("output", help="output onnx model")
    parser.add_argument("--inputs", help="comma separated model input names appended with shapes, e.g. --inputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument("--outputs", help="comma separated model output names appended with shapes, e.g. --outputs <nodename>[1,2,3],<nodename1>[1,2,3] ")
    parser.add_argument('--skipverify', dest='skipverify', action='store_true',
                        help='skip verification of model. Useful if shapes are not known')
    args = parser.parse_args()

    # Empty name lists make onnx_edit keep the model's original inputs/outputs.
    new_input_node_names, input_shape_map = [], {}
    if args.inputs:
        new_input_node_names, input_shape_map = parse_nodename_and_shape(args.inputs)

    new_output_node_names, output_shape_map = [], {}
    if args.outputs:
        new_output_node_names, output_shape_map = parse_nodename_and_shape(args.outputs)

    # parse_nodename_and_shape() returns None for the shape map when no
    # "[...]" shape was given (e.g. "--inputs x"); onnx_edit looks names up in
    # these maps, so hand it an empty mapping instead of None.
    input_shape_map = input_shape_map or {}
    output_shape_map = output_shape_map or {}

    onnx_edit(args.input, args.output, new_input_node_names, input_shape_map,
              new_output_node_names, output_shape_map, not args.skipverify)
    
        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/850215
推荐阅读
相关标签
  

闽ICP备14008679号