e6ec是最后一版可以使用的,在onnx_modifier.py中必须有import parse_tensor才行,不然下载不了修改的模型,windows上的exe没有试不知道,再增加一个resort脚本,必须过一遍,不然会出错

1123
wangchunlin 3 years ago
parent e6ec898f41
commit 76c7de776a

@ -0,0 +1,76 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import time
import sys
def check_update_model(model):
if (model.graph.name == ''):
model.graph.name = 'tackle_checker_error'
try:
onnx.checker.check_model(model)
except onnx.checker.ValidationError as e:
print("model is invalid: %s"%(e))
# 目前只考虑这个错误“ Nodes in a graph must be topologically sorted, however input ...”,
# 就是模型的节点顺序有问题onnxruntime 是可以推断的,为了其他引擎比如 trt可以推断调整一下节点顺序。
else:
time.sleep(1)
print("模型正常,返回原模型")
return model
print('update model ...')
new_model = onnx.ModelProto(ir_version=model.ir_version,
producer_name=model.producer_name,
producer_version=model.producer_version,
opset_import=model.opset_import)
new_model.graph.name = model.graph.name
# 只考虑模型是单输入单输出的情况。
# 这里不需要 copy.deepcopy(model.graph.input[0]), python 是写时copy。\
list_total_input_name = []
for inp in model.graph.input:
new_model.graph.input.append(inp)
list_total_input_name.append(inp.name)
for oup in model.graph.output:
new_model.graph.output.append(oup)
for weight in model.graph.initializer:
new_model.graph.initializer.append(weight)
list_total_input_name.append(weight.name)
list_nodes_invalid = []
for node in model.graph.node:
node_invalid = False
for node_input_name in node.input:
if (list_total_input_name.count(node_input_name) == 0):
list_nodes_invalid.append(node)
node_invalid = True
break
if (node_invalid):
continue
for node_output_name in node.output:
list_total_input_name.append(node_output_name)
new_model.graph.node.append(node)
for node in list_nodes_invalid:
for node_input_name in node.input:
assert list_total_input_name.count(node_input_name) > 0
for node_output_name in node.output:
list_total_input_name.append(node_output_name)
new_model.graph.node.append(node)
return new_model
if __name__ == "__main__":
onnx_in_name = sys.argv[1]
onnx_out_name = sys.argv[2]
onnx_model = onnx.load(onnx_in_name)
new_model = check_update_model(onnx_model)
onnx.checker.check_model(new_model)
onnx.save(new_model, onnx_out_name)
Loading…
Cancel
Save