You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
2.7 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import time
import sys
def check_update_model(model):
    """Validate *model*; if the checker rejects it, return a copy with its nodes re-sorted.

    The failure this is meant to repair is the checker's "Nodes in a graph must
    be topologically sorted, however input ..." error. onnxruntime can run a
    model whose node list is out of order, but other engines (e.g. TensorRT)
    need every producer to appear before its consumers.

    If the checker rejects the model for some other reason, a re-sorted copy is
    still returned. That copy may still be invalid, so callers should re-run
    ``onnx.checker.check_model`` on the result.

    Args:
        model: the ``onnx.ModelProto`` to check. If its graph name is empty, it
            is set to ``'tackle_checker_error'`` IN PLACE. The placeholder works
            around the checker rejecting an empty graph name.

    Returns:
        ``model`` itself when it passes the checker. Otherwise, a new
        ``ModelProto`` that is a full copy of ``model`` (opsets, metadata,
        initializers, value_info, functions, ...) in which only ``graph.node``
        has been reordered.

    Raises:
        ValueError: if some nodes can never be scheduled, because of a
            dependency cycle or an input that no graph input, initializer or
            node produces.
    """
    if not model.graph.name:
        model.graph.name = 'tackle_checker_error'
    try:
        onnx.checker.check_model(model)
    except onnx.checker.ValidationError as e:
        print("model is invalid: %s" % (e))
    else:
        print("模型正常,返回原模型")
        return model

    print('update model ...')
    # Copy the whole model, not selected fields, so that value_info, metadata,
    # functions, sparse initializers, etc. survive. Then swap in the sorted node
    # list. extend() copies each NodeProto into the new graph.
    new_model = onnx.ModelProto()
    new_model.CopyFrom(model)
    sorted_nodes = _topologically_sorted_nodes(model.graph)
    new_model.graph.ClearField("node")
    new_model.graph.node.extend(sorted_nodes)
    return new_model


def _graph_defined_names(graph):
    """Return the value names *graph* provides before any of its nodes run."""
    names = {value.name for value in graph.input}
    names.update(tensor.name for tensor in graph.initializer)
    # A sparse initializer is named by its `values` tensor.
    names.update(sparse.values.name for sparse in graph.sparse_initializer)
    return names


def _node_required_names(node):
    """Return every value name *node* needs, including outer-scope names its subgraphs read."""
    names = set(node.input)
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            subgraphs = [attr.g]
        elif attr.type == onnx.AttributeProto.GRAPHS:
            subgraphs = list(attr.graphs)
        else:
            continue
        for subgraph in subgraphs:
            # Names defined inside the subgraph are local. Anything else it
            # reads (e.g. an If branch using an earlier node's output) is an
            # implicit dependency of this node, which node.input does not list.
            local = _graph_defined_names(subgraph)
            for sub_node in subgraph.node:
                local.update(sub_node.output)
            for sub_node in subgraph.node:
                names.update(_node_required_names(sub_node) - local)
    # An empty name marks an omitted optional input; it needs no producer.
    names.discard("")
    return names


def _topologically_sorted_nodes(graph):
    """Return graph.node in an order where each node follows the producers of its inputs.

    Nodes are scheduled in repeated passes in their original relative order.
    This keeps the original order wherever it is already valid and moves
    deferred nodes only as far as needed.
    """
    available = _graph_defined_names(graph)
    # Compute each node's requirements once; deferred nodes are re-examined.
    pending = [(node, _node_required_names(node)) for node in graph.node]
    ordered = []
    while pending:
        deferred = []
        for node, required in pending:
            if required <= available:
                ordered.append(node)
                available.update(node.output)
            else:
                deferred.append((node, required))
        if len(deferred) == len(pending):
            # No progress in a full pass: a cycle or an input nothing produces.
            details = "; ".join(
                "%s (%s) missing %s" % (node.name, node.op_type, sorted(required - available))
                for node, required in deferred
            )
            raise ValueError("cannot topologically sort graph: " + details)
        pending = deferred
    return ordered
if __name__ == "__main__":
    # Usage: python <script> <input.onnx> <output.onnx>
    # Fail with a usage message rather than an IndexError traceback.
    if len(sys.argv) < 3:
        sys.exit("usage: %s <input.onnx> <output.onnx>" % sys.argv[0])
    onnx_in_name = sys.argv[1]
    onnx_out_name = sys.argv[2]
    onnx_model = onnx.load(onnx_in_name)
    new_model = check_update_model(onnx_model)
    # Re-validate before saving. check_update_model only repairs node order,
    # so any remaining checker error is raised here and nothing is written.
    onnx.checker.check_model(new_model)
    onnx.save(new_model, onnx_out_name)