From 76c7de776aecd7ecd70e4a2ad4eef4a390855112 Mon Sep 17 00:00:00 2001 From: wangchunlin Date: Mon, 5 Dec 2022 15:39:29 +0800 Subject: [PATCH] =?UTF-8?q?e6ec=E6=98=AF=E6=9C=80=E5=90=8E=E4=B8=80?= =?UTF-8?q?=E7=89=88=E5=8F=AF=E4=BB=A5=E4=BD=BF=E7=94=A8=E7=9A=84=EF=BC=8C?= =?UTF-8?q?=E5=9C=A8onnx=5Fmodifier.py=E4=B8=AD=E5=BF=85=E9=A1=BB=E6=9C=89?= =?UTF-8?q?import=20parse=5Ftensor=E6=89=8D=E8=A1=8C=EF=BC=8C=E4=B8=8D?= =?UTF-8?q?=E7=84=B6=E4=B8=8B=E8=BD=BD=E4=B8=8D=E4=BA=86=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E7=9A=84=E6=A8=A1=E5=9E=8B=EF=BC=8Cwindows=E4=B8=8A=E7=9A=84ex?= =?UTF-8?q?e=E6=B2=A1=E6=9C=89=E8=AF=95=E4=B8=8D=E7=9F=A5=E9=81=93?= =?UTF-8?q?=EF=BC=8C=E5=86=8D=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=B8=AAresort?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=EF=BC=8C=E5=BF=85=E9=A1=BB=E8=BF=87=E4=B8=80?= =?UTF-8?q?=E9=81=8D=EF=BC=8C=E4=B8=8D=E7=84=B6=E4=BC=9A=E5=87=BA=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- onnx_resort_node.py | 76 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 onnx_resort_node.py diff --git a/onnx_resort_node.py b/onnx_resort_node.py new file mode 100644 index 0000000..f43029e --- /dev/null +++ b/onnx_resort_node.py @@ -0,0 +1,76 @@ + +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto +import time +import sys + +def check_update_model(model): + if (model.graph.name == ''): + model.graph.name = 'tackle_checker_error' + + try: + onnx.checker.check_model(model) + except onnx.checker.ValidationError as e: + print("model is invalid: %s"%(e)) + # 目前只考虑这个错误“ Nodes in a graph must be topologically sorted, however input ...”, + # 就是模型的节点顺序有问题,onnxruntime 是可以推断的,为了其他引擎比如 trt可以推断,调整一下节点顺序。 + else: + time.sleep(1) + print("模型正常,返回原模型") + return model + + print('update model ...') + new_model = onnx.ModelProto(ir_version=model.ir_version, + producer_name=model.producer_name, + producer_version=model.producer_version, + opset_import=model.opset_import) + + new_model.graph.name = model.graph.name + # 只考虑模型是单输入单输出的情况。 + # 这里不需要 copy.deepcopy(model.graph.input[0]), python 是写时copy。\ + list_total_input_name = [] + for inp in model.graph.input: + new_model.graph.input.append(inp) + list_total_input_name.append(inp.name) + for oup in model.graph.output: + new_model.graph.output.append(oup) + + + for weight in model.graph.initializer: + new_model.graph.initializer.append(weight) + list_total_input_name.append(weight.name) + + list_nodes_invalid = [] + for node in model.graph.node: + node_invalid = False + for node_input_name in node.input: + if (list_total_input_name.count(node_input_name) == 0): + list_nodes_invalid.append(node) + node_invalid = True + break + if (node_invalid): + continue + for node_output_name in node.output: + list_total_input_name.append(node_output_name) + new_model.graph.node.append(node) + + for node in list_nodes_invalid: + for node_input_name in node.input: + assert list_total_input_name.count(node_input_name) > 0 + for node_output_name in node.output: + list_total_input_name.append(node_output_name) + new_model.graph.node.append(node) + + return new_model + + +if __name__ == "__main__": + onnx_in_name = sys.argv[1] + onnx_out_name = sys.argv[2] + + onnx_model = onnx.load(onnx_in_name) + new_model = check_update_model(onnx_model) + onnx.checker.check_model(new_model) + + onnx.save(new_model, onnx_out_name) \ No newline at end of file