|
|
|
|
@ -52,9 +52,9 @@ class onnxModifier:
|
|
|
|
|
# print(self.node_name2module.keys())
|
|
|
|
|
|
|
|
|
|
# initializer name => initializer
|
|
|
|
|
self.initilizer_name2module = dict()
|
|
|
|
|
self.initializer_name2module = dict()
|
|
|
|
|
for initializer in self.initializer:
|
|
|
|
|
self.initilizer_name2module[initializer.name] = initializer
|
|
|
|
|
self.initializer_name2module[initializer.name] = initializer
|
|
|
|
|
|
|
|
|
|
def remove_node_by_name(self, node_name):
|
|
|
|
|
# remove node in graph
|
|
|
|
|
@ -82,9 +82,23 @@ class onnxModifier:
|
|
|
|
|
for left_node in self.graph.node:
|
|
|
|
|
left_node_inputs += left_node.input
|
|
|
|
|
|
|
|
|
|
for init_name in self.initilizer_name2module.keys():
|
|
|
|
|
for init_name in self.initializer_name2module.keys():
|
|
|
|
|
if not init_name in left_node_inputs:
|
|
|
|
|
self.initializer.remove(self.initilizer_name2module[init_name])
|
|
|
|
|
self.initializer.remove(self.initializer_name2module[init_name])
|
|
|
|
|
|
|
|
|
|
# remove the left unused Constant nodes
|
|
|
|
|
for left_node in self.graph.node:
|
|
|
|
|
if left_node.op_type == "Constant":
|
|
|
|
|
output_deleted = [False] * len(left_node.output)
|
|
|
|
|
for i, output in enumerate(left_node.output):
|
|
|
|
|
if not (output in left_node_inputs):
|
|
|
|
|
output_deleted[i] = True
|
|
|
|
|
|
|
|
|
|
const_node_left_output = [left_node.output[i] for i in range(len(left_node.output)) if not output_deleted[i]]
|
|
|
|
|
if len(const_node_left_output) == 0:
|
|
|
|
|
self.graph.node.remove(self.node_name2module[left_node.name])
|
|
|
|
|
# self.initializer.remove(self.initializer_name2module[init_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modify_node_io_name(self, node_renamed_io):
|
|
|
|
|
# print(node_renamed_io)
|
|
|
|
|
@ -113,8 +127,9 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
|
print(modify_info['added_node_info'])
|
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
@ -155,7 +170,9 @@ if __name__ == "__main__":
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_mobilenetv2-7.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
|
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
|
@ -177,40 +194,52 @@ if __name__ == "__main__":
|
|
|
|
|
# explore_basic()
|
|
|
|
|
|
|
|
|
|
def remove_node_by_node_states():
|
|
|
|
|
print(len(onnx_modifier.graph.node))
|
|
|
|
|
print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
node_states_fp = {}
|
|
|
|
|
node_states_quant = {}
|
|
|
|
|
# print(len(onnx_modifier.graph.node))
|
|
|
|
|
# print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
|
|
|
|
|
node_states = node_states_quant
|
|
|
|
|
# node_states = node_states_fp
|
|
|
|
|
# print(onnx_modifier.node_name2module.keys())
|
|
|
|
|
# print(onnx_modifier.graph.node)
|
|
|
|
|
# for node in onnx_modifier.graph.node:
|
|
|
|
|
# print(node.name)
|
|
|
|
|
# print(node.input)
|
|
|
|
|
# print(node.output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
node_states = {'input': 'Exist', 'Conv_0': 'Exist', 'Conv_95': 'Exist', 'Clip_96': 'Deleted', 'GlobalAveragePool_97': 'Deleted', 'Shape_98': 'Deleted', 'Gather_100': 'Deleted', 'Unsqueeze_101': 'Deleted', 'Concat_102': 'Deleted', 'Reshape_103': 'Deleted', 'Gemm_104': 'Deleted', 'out_output': 'Deleted'}
|
|
|
|
|
# print('\graph input')
|
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
|
# print(inp.name)
|
|
|
|
|
onnx_modifier.remove_node_by_node_states(node_states)
|
|
|
|
|
print(len(onnx_modifier.graph.node))
|
|
|
|
|
print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
print(len(onnx_modifier.initilizer_name2module.keys()))
|
|
|
|
|
# print(onnx_modifier.initilizer_name2module.keys())
|
|
|
|
|
# for i, k in enumerate(onnx_modifier.initilizer_name2module.keys()):
|
|
|
|
|
# print(len(onnx_modifier.graph.node))
|
|
|
|
|
# print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
# print(len(onnx_modifier.initializer_name2module.keys()))
|
|
|
|
|
|
|
|
|
|
for node in onnx_modifier.graph.node:
|
|
|
|
|
print(node.name)
|
|
|
|
|
print(node.input, node.output)
|
|
|
|
|
for initializer in onnx_modifier.initializer:
|
|
|
|
|
print(initializer.name)
|
|
|
|
|
|
|
|
|
|
# print(onnx_modifier.initializer_name2module.keys())
|
|
|
|
|
# for i, k in enumerate(onnx_modifier.initializer_name2module.keys()):
|
|
|
|
|
# print("\nremoving", i, k)
|
|
|
|
|
# onnx_modifier.graph.initializer.remove(onnx_modifier.initilizer_name2module[k])
|
|
|
|
|
# onnx_modifier.graph.initializer.remove(onnx_modifier.initializer_name2module[k])
|
|
|
|
|
# print("removed")
|
|
|
|
|
|
|
|
|
|
print('\nleft initializers:')
|
|
|
|
|
for initializer in onnx_modifier.model_proto.graph.initializer:
|
|
|
|
|
print(initializer.name)
|
|
|
|
|
# print('\nleft initializers:')
|
|
|
|
|
# for initializer in onnx_modifier.model_proto.graph.initializer:
|
|
|
|
|
# print(initializer.name)
|
|
|
|
|
|
|
|
|
|
print('\nleft nodes:')
|
|
|
|
|
for node in onnx_modifier.graph.node:
|
|
|
|
|
print(node.name)
|
|
|
|
|
# print('\nleft nodes:')
|
|
|
|
|
# for node in onnx_modifier.graph.node:
|
|
|
|
|
# print(node.name)
|
|
|
|
|
|
|
|
|
|
print('\nleft input')
|
|
|
|
|
for inp in onnx_modifier.graph.input:
|
|
|
|
|
print(inp.name)
|
|
|
|
|
# print('\nleft input')
|
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
|
# print(inp.name)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# remove_node_by_node_states()
|
|
|
|
|
remove_node_by_node_states()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_node_io_name():
|
|
|
|
|
@ -227,6 +256,6 @@ if __name__ == "__main__":
|
|
|
|
|
onnx_modifier.inference()
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
|
|
|
|
|
test_add_node()
|
|
|
|
|
# test_add_node()
|
|
|
|
|
|
|
|
|
|
|