当前位置:   article > 正文

修改onnx模型node_在线修改onnx

在线修改onnx

一、先列出有价值的参考链接/可学习的链接

  • onnx_python_examples 范例(已 fork):链接
  • 一个可视化操作onnx 的git【一些op没实现】:链接
  • 一个比较好的知乎学习:链接
  • 比较好的一个csdn上对修改onnx的总结:链接
  • 一个知乎的简单总结:链接
  • 另一个csdn总结:链接
  • 又一个github修改shape的code:链接

二、代码整理

2.1 修改输入输出节点名称以及模型名称

import onnx

onnx_model = onnx.load('/root/Desktop/qxc_0613/mt_hmr_hardnet_rename.onnx')
export_model = '/root/Desktop/qxc_0613/mt_hmr_hardnet_rename_final.onnx'

# Each entry is "<new_name>:<old_tensor_name>". Only the FIRST ':' separates the
# two parts, so old tensor names that themselves contain ':' are kept intact.
endpoint_names = ['input:input.1', 'POSE:648', 'SHAPE:570', 'CAM:572']

model_name = 'human_params_estimation'

# old tensor name -> new tensor name. Lookups are exact-match dict lookups: a
# substring test (`in` on a str) would also match '' (omitted optional node
# inputs) and partial names such as '64' inside '648'.
rename_map = {}
for endpoint_name in endpoint_names:
    new_name, old_name = endpoint_name.split(':', 1)
    rename_map[old_name] = new_name

# Rename the endpoint tensors wherever a node consumes or produces them.
for node in onnx_model.graph.node:
    # Iterate over a snapshot because entries are reassigned inside the loop.
    for j, tensor_name in enumerate(list(node.input)):
        if tensor_name in rename_map:
            print('-' * 60)
            print("node name: ", node.name)
            print("node input-------: ", tensor_name)
            print("node input all: ", node.input)
            print("node output all: ", node.output)
            node.input[j] = rename_map[tensor_name]

    for j, tensor_name in enumerate(list(node.output)):
        if tensor_name in rename_map:
            print('-' * 60)
            print("node: ", node.name)
            print("node output-----: ", tensor_name)
            print("node input all: ", node.input)
            print("node output all: ", node.output)
            node.output[j] = rename_map[tensor_name]

# Rename the graph's declared inputs/outputs, plus any cached value_info entry,
# so no declaration keeps pointing at a tensor name that no longer exists.
for declared in (onnx_model.graph.input,
                 onnx_model.graph.output,
                 onnx_model.graph.value_info):
    for value_info in declared:
        if value_info.name in rename_map:
            print('-' * 60)
            print(value_info)
            value_info.name = rename_map[value_info.name]

# Set the graph (model) name.
print("before modify onnx_model.graph.name is: ", onnx_model.graph.name)
onnx_model.graph.name = model_name
print("after modify onnx_model.graph.name is: ", onnx_model.graph.name)


# Save the renamed model.
onnx.save(onnx_model, export_model)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

2.2 修改模型 graph 的 initializer 以及 node 属性(包括 check 模型)

import onnx
import numpy as np
import torch

def create_initializer_tensor(
        name: str,
        tensor_array: np.ndarray,
        data_type: int = onnx.TensorProto.FLOAT
) -> onnx.TensorProto:
    """Build an ONNX ``TensorProto`` (for use as a graph initializer) from an array.

    Args:
        name: Tensor name; nodes reference the initializer by this name.
        tensor_array: Values to store. Anything ``np.asarray`` accepts (ndarray,
            nested lists); its shape becomes the tensor's ``dims``.
        data_type: An ``onnx.TensorProto`` data-type enum value such as
            ``onnx.TensorProto.INT64``. This is an int, not a ``TensorProto``
            instance.

    Returns:
        The newly built ``TensorProto``.
    """
    tensor_array = np.asarray(tensor_array)
    return onnx.helper.make_tensor(
        name=name,
        data_type=data_type,
        dims=list(tensor_array.shape),
        # Values are passed as a flat Python list in row-major (C) order.
        vals=tensor_array.flatten().tolist(),
    )

def _make_grid(nx: int, ny: int) -> np.ndarray:
    """Return a (1, 1, ny*nx, 2) int64 array of (x, y) cell coordinates, y-major."""
    # 'ij' indexing: yv[i, j] == i, xv[i, j] == j, so each row is (x=j, y=i).
    yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
    return np.stack((xv, yv), axis=2).reshape(1, 1, ny * nx, 2).astype(np.int64)


def _default_initializer_replacements():
    """Return the ordered (name, array, data_type) patches this script was written for."""
    int64 = onnx.TensorProto.INT64
    axes_3 = np.array([3], dtype=np.int64)
    return [
        # operator: shape -- 4-element shape vectors for the three heads.
        ('276', np.array([1, 1, 85, 6400], dtype=np.int64), int64),
        ('340', np.array([1, 1, 85, 1600], dtype=np.int64), int64),
        ('404', np.array([1, 1, 85, 400], dtype=np.int64), int64),
        # operator: slice axes -> 3 (head 1, head 2, head 3).
        ('279', axes_3, int64), ('288', axes_3, int64), ('296', axes_3, int64),
        ('343', axes_3, int64), ('352', axes_3, int64), ('360', axes_3, int64),
        ('407', axes_3, int64), ('416', axes_3, int64), ('424', axes_3, int64),
        # operator: Add, input B -- coordinate grids for 80x80, 40x40, 20x20.
        # NOTE(review): stored as INT64 exactly as before; confirm the consuming
        # Add's other input has a matching element type.
        ('284', _make_grid(80, 80), int64),
        ('348', _make_grid(40, 40), int64),
        ('412', _make_grid(20, 20), int64),
    ]


def replace_initializer_node(graph, replacements=None):
    """Replace selected initializers of ``graph`` with new constant tensors.

    ONNX constants may be stored as initializers or as Constant nodes; this
    function only touches initializers. It first dumps every initializer so
    the names to patch can be looked up, then, for each replacement, deletes
    the initializer with that name and appends a freshly built one. Replaced
    initializers therefore move to the end of ``graph.initializer``.

    Args:
        graph: ``onnx.GraphProto`` edited in place.
        replacements: Optional iterable of ``(name, array, data_type)`` tuples,
            applied in order. Defaults to the model-specific table from
            ``_default_initializer_replacements``.

    Returns:
        The same ``graph`` object.
    """
    for init_idx, initializer in enumerate(graph.initializer):
        print("######%s######" % init_idx)
        print(initializer)
        print('--------next initializer:')

    if replacements is None:
        replacements = _default_initializer_replacements()

    for name, array, data_type in replacements:
        index = next(
            (i for i, init in enumerate(graph.initializer) if init.name == name),
            None,
        )
        if index is None:
            # Previously a typo'd or absent name was skipped without a trace.
            print("warning: initializer %r not found, skipped" % name)
            continue
        del graph.initializer[index]
        graph.initializer.append(
            create_initializer_tensor(name, array, data_type=data_type))

    return graph

def modify_node_attribute(graph, perms=None):
    """Set (replace or add) the ``perm`` attribute on selected Transpose nodes.

    Args:
        graph: ``onnx.GraphProto`` edited in place.
        perms: Optional mapping of node name -> new ``perm`` list. Defaults to
            ``[0, 1, 3, 2]`` for Transpose_121, Transpose_177 and Transpose_233.

    Returns:
        The same ``graph`` object.
    """
    if perms is None:
        perms = {
            "Transpose_121": [0, 1, 3, 2],
            "Transpose_177": [0, 1, 3, 2],
            "Transpose_233": [0, 1, 3, 2],
        }

    for node in graph.node:
        perm = perms.get(node.name)
        if perm is None:
            continue

        new_attr = onnx.helper.make_attribute("perm", perm)
        for attr in node.attribute:
            if attr.name == "perm":
                # Overwrite in place. Deleting from and appending to
                # node.attribute while iterating it would revisit the new entry.
                attr.CopyFrom(new_attr)
                break
        else:
            # No explicit perm on this node: add one ("replace or add attr").
            node.attribute.extend([new_attr])
        print("%s: perm -> %s" % (node.name, list(perm)))

    return graph




def main():
    """Load the model, patch initializers and Transpose perms, validate, save."""
    # ---------- 0. paths ----------
    input_onnx_path = "/root/Desktop/v6n.opt.onnx"
    output_onnx_path = "/root/Desktop/v6n.opt.modify1.onnx"

    # ---------- 1.1 load the onnx file ----------
    model = onnx.load(input_onnx_path)
    graph = model.graph

    # ---------- 1.2 edit the graph in place ----------
    replace_initializer_node(graph)
    modify_node_attribute(graph)

    # The edits change tensor shapes, so cached intermediate shapes would be
    # stale. (The previous make_graph-based rebuild dropped value_info as well.)
    graph.ClearField("value_info")

    # ---------- 2. validate ----------
    # Check a shape-inferred copy so inference errors surface here.
    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)

    # ---------- 3. save ----------
    # Save the loaded ModelProto itself so opset_import, ir_version and metadata
    # survive. onnx.helper.make_model(graph) would reset them to this onnx
    # release's defaults, silently re-targeting every op to a newer opset.
    onnx.save(model, output_onnx_path)


if __name__ == '__main__':
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号