"""Repair ONNX models whose graph nodes are not topologically sorted.

Usage: python <this_script> <input.onnx> <output.onnx>
"""
import sys
import time

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto


def _sort_nodes_topologically(nodes, available_names):
    """Return ``nodes`` ordered so each input is produced before it is used.

    Args:
        nodes: iterable of objects exposing ``input`` and ``output`` name
            sequences (ONNX ``NodeProto`` messages in this file).
        available_names: names that are usable before any node runs (graph
            inputs and initializers).

    Returns:
        A new list holding the same node objects. The ordering is stable:
        a node is only moved later when one of its inputs is produced by a
        node that appears after it.

    Raises:
        ValueError: if some nodes can never be scheduled, i.e. they consume
            a name nobody produces or they form a cycle.
    """
    available = set(available_names)
    # ONNX encodes an omitted optional input as an empty name; it never
    # needs a producer, so treat it as always satisfied.
    available.add('')

    ordered = []
    pending = list(nodes)
    while pending:
        deferred = []
        for node in pending:
            if all(name in available for name in node.input):
                ordered.append(node)
                available.update(node.output)
            else:
                deferred.append(node)
        if len(deferred) == len(pending):
            # A full pass placed nothing, so the remaining nodes are stuck.
            missing = sorted(set(
                name
                for node in deferred
                for name in node.input
                if name not in available
            ))
            raise ValueError(
                "cannot topologically sort graph: %d node(s) depend on "
                "unavailable inputs %s" % (len(deferred), missing)
            )
        pending = deferred
    return ordered


def check_update_model(model):
    """Return a model whose graph nodes are in topological order.

    The model is first run through ``onnx.checker.check_model``. If it
    passes, the very same object is returned. If the checker raises
    ``ValidationError``, the failure is assumed to be "Nodes in a graph must
    be topologically sorted" (onnxruntime tolerates such models, other
    engines such as TensorRT may not), and a copy of the model with its
    top-level nodes reordered is returned. Any other validation problem is
    left in place for the caller to detect by re-checking the result.

    Only the top-level graph's nodes are reordered; nodes inside subgraph
    attributes are copied unchanged.

    Side effect: an empty ``model.graph.name`` is replaced in place with
    ``'tackle_checker_error'`` because the checker rejects unnamed graphs.

    Args:
        model: the ``onnx.ModelProto`` to validate and, if needed, repair.

    Returns:
        ``model`` itself when it is already valid, otherwise a new
        ``onnx.ModelProto`` with reordered nodes.

    Raises:
        ValueError: if the nodes cannot be ordered (an input has no
            producer, or the graph contains a cycle).
    """
    graph = model.graph
    if graph.name == '':
        graph.name = 'tackle_checker_error'

    try:
        onnx.checker.check_model(model)
    except onnx.checker.ValidationError as e:
        print("model is invalid: %s" % (e))
    else:
        print("模型正常,返回原模型")
        return model

    print('update model ...')

    # Names usable before any node executes: graph inputs and initializers.
    available_names = [value.name for value in graph.input]
    available_names.extend(weight.name for weight in graph.initializer)
    ordered_nodes = _sort_nodes_topologically(graph.node, available_names)

    # Copy the whole model so metadata, value_info, functions, etc. survive,
    # then swap in the reordered node list (extend() copies each message).
    new_model = onnx.ModelProto()
    new_model.CopyFrom(model)
    del new_model.graph.node[:]
    new_model.graph.node.extend(ordered_nodes)
    return new_model


def main(argv=None):
    """Command-line entry point: load, repair, re-check and save a model.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``. Expects the input path then the output path.

    Returns:
        Process exit status: 0 on success, 2 when arguments are missing.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        print("usage: %s <input.onnx> <output.onnx>" % sys.argv[0],
              file=sys.stderr)
        return 2

    onnx_in_name = args[0]
    onnx_out_name = args[1]
    onnx_model = onnx.load(onnx_in_name)
    new_model = check_update_model(onnx_model)
    # Re-validate so remaining (non-ordering) problems surface before saving.
    onnx.checker.check_model(new_model)
    onnx.save(new_model, onnx_out_name)
    return 0


if __name__ == "__main__":
    sys.exit(main())