|
|
|
|
|
import onnx
|
|
|
from onnx import helper
|
|
|
from onnx import AttributeProto, TensorProto, GraphProto
|
|
|
import time
|
|
|
import sys
|
|
|
|
|
|
def _sort_nodes_topologically(nodes, available_names):
    """Return ``nodes`` reordered so each node follows the producers of its inputs.

    The sort is stable: nodes are repeatedly scanned in their current order and
    every node whose inputs are all available is emitted, so nodes that were
    already correctly placed keep their relative order.

    Args:
        nodes: Iterable of node objects exposing ``input`` and ``output``
            sequences of tensor names.
        available_names: Names that are available before any node runs
            (graph inputs and initializers).

    Returns:
        A list containing every element of ``nodes`` in a valid order.

    Raises:
        ValueError: If some node can never be satisfied (an input that nothing
            produces, or a cycle between nodes).
    """
    known = set(available_names)
    ordered = []
    pending = list(nodes)
    while pending:
        deferred = []
        for node in pending:
            # An empty name marks an omitted optional input, not a dependency.
            if all(name == "" or name in known for name in node.input):
                ordered.append(node)
                known.update(node.output)
            else:
                deferred.append(node)
        if len(deferred) == len(pending):
            # A full scan made no progress: the remaining nodes are unsortable.
            unresolved = sorted({
                name
                for node in deferred
                for name in node.input
                if name and name not in known
            })
            raise ValueError(
                "cannot topologically sort graph; unresolved inputs: %s"
                % ", ".join(unresolved)
            )
        pending = deferred
    return ordered


def check_update_model(model):
    """Validate an ONNX model and, if validation fails, re-sort its nodes.

    If the graph has no name, a placeholder name is assigned on ``model``
    itself (the checker rejects unnamed graphs). When ``onnx.checker`` accepts
    the model it is returned unchanged. Otherwise a copy of the model is
    returned whose nodes are reordered so that every node appears after the
    nodes producing its inputs; all other model and graph fields are preserved.

    NOTE: only names listed in ``node.input`` are treated as dependencies;
    values captured implicitly by subgraph attributes are not analysed.

    Args:
        model: The ``onnx.ModelProto`` to check.

    Returns:
        ``model`` itself when it is valid, otherwise a new reordered
        ``onnx.ModelProto``.

    Raises:
        ValueError: If the nodes cannot be ordered (missing producer or cycle).
    """
    if model.graph.name == '':
        model.graph.name = 'tackle_checker_error'

    try:
        onnx.checker.check_model(model)
    except onnx.checker.ValidationError as e:
        print("model is invalid: %s" % (e))
        # Only the error "Nodes in a graph must be topologically sorted,
        # however input ..." is handled here: the node order is wrong.
        # onnxruntime can still run such a model, but other engines (e.g.
        # TensorRT) cannot, so the nodes are reordered below.
    else:
        print("模型正常,返回原模型")
        return model

    print('update model ...')
    graph = model.graph

    # Names that exist before any node executes.
    available = {inp.name for inp in graph.input}
    available.update(weight.name for weight in graph.initializer)
    # Older onnx releases have no sparse_initializer field, hence getattr.
    for sparse in getattr(graph, "sparse_initializer", ()):
        available.add(sparse.values.name)

    ordered_nodes = _sort_nodes_topologically(graph.node, available)

    # Copy the whole model so that no field is lost, then replace only the
    # node list. Protobuf copies messages on CopyFrom/extend, so the input
    # model's nodes are not shared with the result.
    new_model = onnx.ModelProto()
    new_model.CopyFrom(model)
    new_model.graph.ClearField("node")
    new_model.graph.node.extend(ordered_nodes)

    return new_model
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: python <this script> <input.onnx> <output.onnx>
    # Fail with a readable message instead of an IndexError traceback when
    # the paths are missing; extra arguments are still ignored as before.
    if len(sys.argv) < 3:
        sys.exit("usage: %s <input.onnx> <output.onnx>" % sys.argv[0])

    onnx_in_name = sys.argv[1]
    onnx_out_name = sys.argv[2]

    onnx_model = onnx.load(onnx_in_name)
    new_model = check_update_model(onnx_model)
    # Re-validate: raises if reordering alone did not make the model valid.
    onnx.checker.check_model(new_model)

    onnx.save(new_model, onnx_out_name)