|
|
|
|
@ -10,6 +10,7 @@ class onnxModifier:
|
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
|
self.model_name = model_name
|
|
|
|
|
self.model_proto_backup = model_proto
|
|
|
|
|
self.reload()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_path(cls, model_path):
|
|
|
|
|
@ -78,12 +79,29 @@ class onnxModifier:
|
|
|
|
|
for init_name in self.initilizer_name2module.keys():
|
|
|
|
|
if not init_name in left_node_inputs:
|
|
|
|
|
self.initializer.remove(self.initilizer_name2module[init_name])
|
|
|
|
|
|
|
|
|
|
def modify_node_io_name(self, node_renamed_io):
|
|
|
|
|
# print(node_renamed_io)
|
|
|
|
|
for node_name in node_renamed_io.keys():
|
|
|
|
|
renamed_ios = node_renamed_io[node_name]
|
|
|
|
|
for src_name, dst_name in renamed_ios.items():
|
|
|
|
|
# print(src_name, dst_name)
|
|
|
|
|
node = self.node_name2module[node_name]
|
|
|
|
|
# print(node.input, node.output)
|
|
|
|
|
for i in range(len(node.input)):
|
|
|
|
|
if node.input[i] == src_name:
|
|
|
|
|
node.input[i] = dst_name
|
|
|
|
|
for i in range(len(node.output)):
|
|
|
|
|
if node.output[i] == src_name:
|
|
|
|
|
node.output[i] = dst_name
|
|
|
|
|
# print(node.input, node.output)
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./res_onnx'):
|
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
|
# onnx.checker.check_model(self.model_proto)
|
|
|
|
|
onnx.checker.check_model(self.model_proto)
|
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(self):
|
|
|
|
|
# model_proto_bytes = onnx._serialize(model_proto_from_stream)
|
|
|
|
|
# inference_session = rt.InferenceSession(model_proto_bytes)
|
|
|
|
|
@ -91,9 +109,9 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
def remove_node_by_node_states():
|
|
|
|
|
@ -134,7 +152,7 @@ if __name__ == "__main__":
|
|
|
|
|
print(inp.name)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
remove_node_by_node_states()
|
|
|
|
|
# remove_node_by_node_states()
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
|
print(type(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
@ -143,15 +161,20 @@ if __name__ == "__main__":
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.node))
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
|
|
|
|
|
# for node in onnx_modifier.model_proto.graph.node:
|
|
|
|
|
# print(node.name)
|
|
|
|
|
# print(node.input)
|
|
|
|
|
# print()
|
|
|
|
|
for node in onnx_modifier.model_proto.graph.node:
|
|
|
|
|
print(node.name)
|
|
|
|
|
print(node.input)
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
# for initializer in onnx_modifier.model_proto.graph.initializer:
|
|
|
|
|
# print(initializer.name)
|
|
|
|
|
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
|
|
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
# explore_basic()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_node_io_name():
|
|
|
|
|
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
|
|
|
|
|
onnx_modifier.modify_node_io_name(node_rename_io)
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
test_modify_node_io_name()
|
|
|
|
|
|