diff --git a/onnx_modifier.py b/onnx_modifier.py index 20228c1..c362cf2 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -52,9 +52,9 @@ class onnxModifier: # print(self.node_name2module.keys()) # initializer name => initializer - self.initilizer_name2module = dict() + self.initializer_name2module = dict() for initializer in self.initializer: - self.initilizer_name2module[initializer.name] = initializer + self.initializer_name2module[initializer.name] = initializer def remove_node_by_name(self, node_name): # remove node in graph @@ -82,9 +82,23 @@ class onnxModifier: for left_node in self.graph.node: left_node_inputs += left_node.input - for init_name in self.initilizer_name2module.keys(): + for init_name in self.initializer_name2module.keys(): if not init_name in left_node_inputs: - self.initializer.remove(self.initilizer_name2module[init_name]) + self.initializer.remove(self.initializer_name2module[init_name]) + + # remove the left unused Constant nodes + for left_node in self.graph.node: + if left_node.op_type == "Constant": + output_deleted = [False] * len(left_node.output) + for i, output in enumerate(left_node.output): + if not (output in left_node_inputs): + output_deleted[i] = True + + const_node_left_output = [left_node.output[i] for i in range(len(left_node.output)) if not output_deleted[i]] + if len(const_node_left_output) == 0: + self.graph.node.remove(self.node_name2module[left_node.name]) + # self.initializer.remove(self.initializer_name2module[init_name]) + def modify_node_io_name(self, node_renamed_io): # print(node_renamed_io) @@ -113,8 +127,9 @@ class onnxModifier: def modify(self, modify_info): + # print(modify_info['node_states']) # print(modify_info['node_renamed_io']) - print(modify_info['added_node_info']) + # print(modify_info['added_node_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.add_node(modify_info['added_node_info'], modify_info['node_states']) @@ -155,7 +170,9 @@ if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_mobilenetv2-7.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): @@ -177,40 +194,52 @@ if __name__ == "__main__": # explore_basic() def remove_node_by_node_states(): - print(len(onnx_modifier.graph.node)) - print(len(onnx_modifier.graph.initializer)) - node_states_fp = {} - node_states_quant = {} + # print(len(onnx_modifier.graph.node)) + # print(len(onnx_modifier.graph.initializer)) - node_states = node_states_quant - # node_states = node_states_fp + # print(onnx_modifier.node_name2module.keys()) + # print(onnx_modifier.graph.node) + # for node in onnx_modifier.graph.node: + # print(node.name) + # print(node.input) + # print(node.output) + + + node_states = {'input': 'Exist', 'Conv_0': 'Exist', 'Conv_95': 'Exist', 'Clip_96': 'Deleted', 'GlobalAveragePool_97': 'Deleted', 'Shape_98': 'Deleted', 'Gather_100': 'Deleted', 'Unsqueeze_101': 'Deleted', 'Concat_102': 'Deleted', 'Reshape_103': 'Deleted', 'Gemm_104': 'Deleted', 'out_output': 'Deleted'} # print('\graph input') # for inp in onnx_modifier.graph.input: # print(inp.name) onnx_modifier.remove_node_by_node_states(node_states) - print(len(onnx_modifier.graph.node)) - print(len(onnx_modifier.graph.initializer)) - print(len(onnx_modifier.initilizer_name2module.keys())) - # print(onnx_modifier.initilizer_name2module.keys()) - # for i, k in enumerate(onnx_modifier.initilizer_name2module.keys()): + # print(len(onnx_modifier.graph.node)) + # print(len(onnx_modifier.graph.initializer)) + # print(len(onnx_modifier.initializer_name2module.keys())) + + for node in onnx_modifier.graph.node: + print(node.name) + print(node.input, node.output) + for initializer in onnx_modifier.initializer: + print(initializer.name) + + # print(onnx_modifier.initializer_name2module.keys()) + # for i, k in enumerate(onnx_modifier.initializer_name2module.keys()): # print("\nremoving", i, k) - # onnx_modifier.graph.initializer.remove(onnx_modifier.initilizer_name2module[k]) + # onnx_modifier.graph.initializer.remove(onnx_modifier.initializer_name2module[k]) # print("removed") - print('\nleft initializers:') - for initializer in onnx_modifier.model_proto.graph.initializer: - print(initializer.name) + # print('\nleft initializers:') + # for initializer in onnx_modifier.model_proto.graph.initializer: + # print(initializer.name) - print('\nleft nodes:') - for node in onnx_modifier.graph.node: - print(node.name) + # print('\nleft nodes:') + # for node in onnx_modifier.graph.node: + # print(node.name) - print('\nleft input') - for inp in onnx_modifier.graph.input: - print(inp.name) + # print('\nleft input') + # for inp in onnx_modifier.graph.input: + # print(inp.name) onnx_modifier.check_and_save_model() - # remove_node_by_node_states() + remove_node_by_node_states() def test_modify_node_io_name(): @@ -227,6 +256,6 @@ if __name__ == "__main__": onnx_modifier.inference() onnx_modifier.check_and_save_model() - test_add_node() + # test_add_node() \ No newline at end of file