赞
踩
2.1 修改输入输出节点名称以及模型名称
import onnx
# Rename selected graph endpoints (graph inputs/outputs and every node port
# that references them) and set the graph's name, then save the model.
onnx_model = onnx.load('/root/Desktop/qxc_0613/mt_hmr_hardnet_rename.onnx')
export_model = '/root/Desktop/qxc_0613/mt_hmr_hardnet_rename_final.onnx'
# Each entry is "new_name:old_name".
endpoint_names = ['input:input.1', 'POSE:648', 'SHAPE:570', 'CAM:572']
model_name = 'human_params_estimation'

# old tensor name -> new tensor name. Lookups are exact matches: a substring
# test (`in`) would also match empty tensor names ("" is contained in every
# string) and partial names such as "57" inside "570".
rename_map = {}
for endpoint_name in endpoint_names:
    parts = endpoint_name.split(':')
    rename_map[parts[1]] = parts[0]

graph = onnx_model.graph

# Rewire node inputs/outputs that reference a renamed tensor.
for node in graph.node:
    for idx, tensor_name in enumerate(node.input):
        if tensor_name in rename_map:
            print('-' * 60)
            print("node name: ", node.name)
            print("node input-------: ", tensor_name)
            print("node input all: ", node.input)
            print("node output all: ", node.output)
            node.input[idx] = rename_map[tensor_name]
    for idx, tensor_name in enumerate(node.output):
        if tensor_name in rename_map:
            print('-' * 60)
            print("node: ", node.name)
            print("node output-----: ", tensor_name)
            print("node input all: ", node.input)
            print("node output all: ", node.output)
            node.output[idx] = rename_map[tensor_name]

# Rename the graph-level input and output declarations themselves.
for value_info in graph.input:
    if value_info.name in rename_map:
        print('-' * 60)
        print(value_info)
        value_info.name = rename_map[value_info.name]
for value_info in graph.output:
    if value_info.name in rename_map:
        print('-' * 60)
        print(value_info)
        value_info.name = rename_map[value_info.name]

# Set the model (graph) name.
print("before modify onnx_model.graph.name is: ", graph.name)
graph.name = model_name
print("after modify onnx_model.graph.name is: ", graph.name)

# Save the model.
onnx.save(onnx_model, export_model)
2.2 修改模型 graph 的 initializer 以及 node 属性(包括检查模型)
import onnx
import numpy as np
import torch
def create_initializer_tensor(
    name: str,
    tensor_array: np.ndarray,
    data_type: onnx.TensorProto = onnx.TensorProto.FLOAT
) -> onnx.TensorProto:
    """Wrap a numpy array in an ONNX TensorProto for use as a graph initializer.

    Args:
        name: Tensor name as it will appear in the graph.
        tensor_array: Source values; its ``shape`` becomes the tensor dims and
            its flattened values become the tensor contents.
        data_type: ``onnx.TensorProto`` element-type enum value (for example
            ``onnx.TensorProto.INT64``); defaults to FLOAT.

    Returns:
        The TensorProto produced by ``onnx.helper.make_tensor``.
    """
    dims = tensor_array.shape
    flat_values = tensor_array.flatten().tolist()
    return onnx.helper.make_tensor(
        name=name,
        data_type=data_type,
        dims=dims,
        vals=flat_values,
    )
def _make_grid(nx, ny):
    """Return the [x, y] coordinates of every cell of an ny-by-nx grid.

    Shape is (1, 1, ny * nx, 2), dtype int64. Cells are listed row by row, so
    entry k is [k % nx, k // nx]. Same values as
    ``torch.stack((xv, yv), 2).view(1, 1, ny * nx, 2)`` with
    ``yv, xv = torch.meshgrid([arange(ny), arange(nx)])``, cast to int64.
    """
    # numpy's default 'xy' indexing yields xv[i, j] == j and yv[i, j] == i,
    # both shaped (ny, nx) -- identical to torch.meshgrid's 'ij' call on (y, x).
    xv, yv = np.meshgrid(np.arange(nx), np.arange(ny))
    return np.stack((xv, yv), axis=2).reshape(1, 1, ny * nx, 2).astype(np.int64)


def _initializer_overrides():
    """Return (initializer_name, new_value) pairs in the order they are applied."""
    return [
        # operator: shape -- the last dim equals the cell count of the matching
        # grid below (80*80, 40*40, 20*20).
        ('276', np.array([1, 1, 85, 6400], dtype=np.int64)),
        ('340', np.array([1, 1, 85, 1600], dtype=np.int64)),
        ('404', np.array([1, 1, 85, 400], dtype=np.int64)),
        # operator: Slice axes -> [3], three initializers per branch.
        ('279', np.array([3], dtype=np.int64)),
        ('288', np.array([3], dtype=np.int64)),
        ('296', np.array([3], dtype=np.int64)),
        ('343', np.array([3], dtype=np.int64)),
        ('352', np.array([3], dtype=np.int64)),
        ('360', np.array([3], dtype=np.int64)),
        ('407', np.array([3], dtype=np.int64)),
        ('416', np.array([3], dtype=np.int64)),
        ('424', np.array([3], dtype=np.int64)),
        # operator: Add input B -- cell-coordinate grids.
        # NOTE(review): stored as INT64 exactly as before (the original built a
        # float grid and then cast it to int64). ONNX Add needs both inputs to
        # share one element type -- confirm the other Add input is int64.
        ('284', _make_grid(80, 80)),
        ('348', _make_grid(40, 40)),
        ('412', _make_grid(20, 20)),
    ]


def _replace_initializer(graph, name, tensor_arr, data_type):
    """Replace the first initializer called `name`; return False if none exists."""
    for initid, initializer in enumerate(graph.initializer):
        if initializer.name == name:
            # Deleting and then returning immediately keeps the loop safe.
            del graph.initializer[initid]
            graph.initializer.append(
                create_initializer_tensor(name, tensor_arr, data_type=data_type))
            return True
    return False


def replace_initializer_node(graph):
    """Overwrite a fixed set of initializers in ``graph`` (in place) and return it.

    Every initializer listed in ``_initializer_overrides()`` is deleted and a
    new INT64 tensor with the same name is appended to ``graph.initializer``.
    Names that are not present in the graph are reported and skipped.

    Constants in an ONNX graph may be stored either as initializers or as
    Constant nodes; only initializers are handled here.
    """
    # Debug dump of all initializers before any modification.
    for initid, initializer in enumerate(graph.initializer):
        print("######%s######" % initid)
        print(initializer)
        print('--------next initializer:')

    for name, tensor_arr in _initializer_overrides():
        if not _replace_initializer(graph, name, tensor_arr, onnx.TensorProto.INT64):
            print("warning: initializer '%s' not found; left unchanged" % name)
    return graph
def modify_node_attribute(
    graph,
    *,
    node_names=("Transpose_121", "Transpose_177", "Transpose_233"),
    perm=(0, 1, 3, 2),
):
    """Overwrite the ``perm`` attribute of selected nodes (in place).

    Args:
        graph: Graph whose ``node`` entries are scanned; modified in place.
        node_names: Names of the nodes whose ``perm`` attribute is rewritten.
        perm: New permutation to store in each matching ``perm`` attribute.

    Returns:
        The same ``graph`` object.
    """
    targets = set(node_names)
    new_perm = list(perm)
    for node in graph.node:
        if node.name not in targets:
            continue
        for attr in node.attribute:
            print("attr.name:", attr.name)
            print("attr.type:", attr.type)
            if attr.name == "perm":
                # Rewrite the values in place rather than deleting from and
                # appending to node.attribute while iterating over it (that
                # skipped elements and re-visited the newly appended attr).
                del attr.ints[:]
                attr.ints.extend(new_perm)
    return graph
def main(input_onnx_path="/root/Desktop/v6n.opt.onnx",
         output_onnx_path="/root/Desktop/v6n.opt.modify1.onnx"):
    """Patch initializers and Transpose perms of an ONNX model, validate, save.

    The graph is edited in place on the loaded ModelProto, so the model's
    opset_import, ir_version, metadata and doc strings are preserved.
    Rebuilding it via onnx.helper.make_graph/make_model would drop them and
    stamp onnx's default opset onto the model instead.

    Args:
        input_onnx_path: Model file to read.
        output_onnx_path: Where the validated, modified model is written.
    """
    # ---------- 1. load ----------
    model = onnx.load(input_onnx_path)
    # ---------- 2. patch the graph in place ----------
    replace_initializer_node(model.graph)
    modify_node_attribute(model.graph)
    # Existing value_info describes pre-patch shapes (e.g. old reshape
    # targets); drop it so shape inference recomputes everything.
    del model.graph.value_info[:]
    # ---------- 3. validate and save the validated model ----------
    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)
    onnx.save(inferred_model, output_onnx_path)


if __name__ == '__main__':
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。