diff --git a/onnx_resort_node.py b/onnx_resort_node.py new file mode 100644 index 0000000..f43029e --- /dev/null +++ b/onnx_resort_node.py @@ -0,0 +1,76 @@ + +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto +import time +import sys + +def check_update_model(model): + if (model.graph.name == ''): + model.graph.name = 'tackle_checker_error' + + try: + onnx.checker.check_model(model) + except onnx.checker.ValidationError as e: + print("model is invalid: %s"%(e)) + # 目前只考虑这个错误“ Nodes in a graph must be topologically sorted, however input ...”, + # 就是模型的节点顺序有问题,onnxruntime 是可以推断的,为了其他引擎比如 trt可以推断,调整一下节点顺序。 + else: + time.sleep(1) + print("模型正常,返回原模型") + return model + + print('update model ...') + new_model = onnx.ModelProto(ir_version=model.ir_version, + producer_name=model.producer_name, + producer_version=model.producer_version, + opset_import=model.opset_import) + + new_model.graph.name = model.graph.name + # 只考虑模型是单输入单输出的情况。 + # 这里不需要 copy.deepcopy(model.graph.input[0]), python 是写时copy。\ + list_total_input_name = [] + for inp in model.graph.input: + new_model.graph.input.append(inp) + list_total_input_name.append(inp.name) + for oup in model.graph.output: + new_model.graph.output.append(oup) + + + for weight in model.graph.initializer: + new_model.graph.initializer.append(weight) + list_total_input_name.append(weight.name) + + list_nodes_invalid = [] + for node in model.graph.node: + node_invalid = False + for node_input_name in node.input: + if (list_total_input_name.count(node_input_name) == 0): + list_nodes_invalid.append(node) + node_invalid = True + break + if (node_invalid): + continue + for node_output_name in node.output: + list_total_input_name.append(node_output_name) + new_model.graph.node.append(node) + + for node in list_nodes_invalid: + for node_input_name in node.input: + assert list_total_input_name.count(node_input_name) > 0 + for node_output_name in node.output: + list_total_input_name.append(node_output_name) + new_model.graph.node.append(node) + + return new_model + + +if __name__ == "__main__": + onnx_in_name = sys.argv[1] + onnx_out_name = sys.argv[2] + + onnx_model = onnx.load(onnx_in_name) + new_model = check_update_model(onnx_model) + onnx.checker.check_model(new_model) + + onnx.save(new_model, onnx_out_name) \ No newline at end of file