From 046f342c2ed9cfe5102e0cfbb3e761e6ccd12f8b Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Thu, 28 Apr 2022 21:05:50 +0800 Subject: [PATCH] remove initializers basically done --- app.py | 2 +- onnx_modifier.py | 116 +++++++++++++++++++++++++++++------------ static/view-sidebar.js | 1 + 3 files changed, 84 insertions(+), 35 deletions(-) diff --git a/app.py b/app.py index bfa6f5e..466534d 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,7 @@ def return_file(): def modify_and_download_model(): node_states = json.loads(request.get_json()) - # print(modelNodeStates) + print(node_states) onnx_modifier.remove_node_by_node_states(node_states) onnx_modifier.check_and_save_model() diff --git a/onnx_modifier.py b/onnx_modifier.py index 1341050..92dfd90 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -1,6 +1,7 @@ # https://leimao.github.io/blog/ONNX-Python-API/ # https://github.com/saurabh-shandilya/onnx-utils -# +# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model + import io import os import onnx @@ -10,16 +11,15 @@ class onnxModifier: self.model_name = model_name self.model_proto = model_proto self.graph = self.model_proto.graph + self.initializer = self.model_proto.graph.initializer - self.gen_node_name2module_map() + self.gen_name2module_map() - def gen_node_name2module_map(self): + + def gen_name2module_map(self): + # node name => node self.node_name2module = dict() node_idx = 0 - # for node in self.graph.input: - # node_idx += 1 - # self.node_name2module[node.name] = node - for node in self.graph.node: if node.name == '': node.name = str(node.op_type) + str(node_idx) @@ -31,6 +31,11 @@ class onnxModifier: self.graph_output_names = [out.name for out in self.graph.output] # print(self.node_name2module.keys()) + # initializer name => initializer + self.initilizer_name2module = dict() + for initializer in self.initializer: + self.initilizer_name2module[initializer.name] = initializer + @classmethod def from_model_path(cls, model_path): model_name = os.path.basename(model_path) @@ -45,12 +50,14 @@ class onnxModifier: return cls(name, model_proto) def remove_node_by_name(self, node_name): - self.graph.node.remove(self.node_name2module[node_name]) + # remove node in graph + self.graph.node.remove(self.node_name2module[node_name]) def remove_output_by_name(self, node_name): self.graph.output.remove(self.node_name2module[node_name]) def remove_node_by_node_states(self, node_states): + # remove node from graph for node_name, node_state in node_states.items(): if node_state == 'Deleted': if node_name in self.graph_output_names: @@ -59,7 +66,17 @@ class onnxModifier: else: # print('removing node {} ...'.format(node_name)) self.remove_node_by_name(node_name) - + + # remove node initializers (parameters) + # aka, keep and only keep the initializers of left nodes + left_node_inputs = [] + for left_node in self.graph.node: + left_node_inputs += left_node.input + + for init_name in self.initilizer_name2module.keys(): + if not init_name in left_node_inputs: + self.initializer.remove(self.initilizer_name2module[init_name]) + def check_and_save_model(self, save_dir='./res_onnx'): save_path = os.path.join(save_dir, 'modified_' + self.model_name) onnx.checker.check_model(self.model_proto) @@ -75,31 +92,62 @@ if __name__ == "__main__": model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) - # for node in onnx_modifier.graph.node: - # print(node.name) - # for node in onnx_modifier.graph.output: - # print(node.name) - print(onnx_modifier.node_name2module.keys()) - print(onnx_modifier.graph_output_names) - - # onnx_modifier.remove_node_by_name('Softmax_nc_rename_64') - # onnx_modifier.remove_output_by_name('softmaxout_1') - # onnx_modifier.graph.output.remove(onnx_modifier.node_name2module['softmaxout_1']) - # onnx_modifier.check_and_save_model() - - # print(type(onnx_modifier.graph.input)) - # print(type(onnx_modifier.graph.output)) - # print(onnx_modifier.graph.input) - # print(onnx_modifier.graph.output) - # print(onnx_modifier.node_name2module['Softmax_nc_rename_64']) - # print(onnx_modifier.node_name2module['softmaxout_1']) - # onnx_modifier.remove_node_by_name('softmaxout_1') - - - # for node in onnx_modifier.graph.output: - # print(node.name) + def remove_node_by_node_states(): + print(len(onnx_modifier.graph.node)) + print(len(onnx_modifier.graph.initializer)) + node_states_fp = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Deleted', 'Relu8': 'Deleted', 'Concat9': 'Deleted', 'Conv10': 'Deleted', 'Relu11': 'Deleted', 'Conv12': 'Deleted', 'Relu13': 'Deleted', 'Conv14': 'Deleted', 'Relu15': 'Deleted', 'Concat16': 'Deleted', 'MaxPool17': 'Deleted', 'Conv18': 'Deleted', 'Relu19': 'Deleted', 'Conv20': 'Deleted', 'Relu21': 'Deleted', 'Conv22': 'Deleted', 'Relu23': 'Deleted', 'Concat24': 'Deleted', 'Conv25': 'Deleted', 'Relu26': 'Deleted', 'Conv27': 'Deleted', 'Relu28': 'Deleted', 'Conv29': 'Deleted', 'Relu30': 'Deleted', 'Concat31': 'Deleted', 'MaxPool32': 'Deleted', 'Conv33': 'Deleted', 'Relu34': 'Deleted', 'Conv35': 'Deleted', 'Relu36': 'Deleted', 'Conv37': 'Deleted', 'Relu38': 'Deleted', 'Concat39': 'Deleted', 'Conv40': 'Deleted', 'Relu41': 'Deleted', 'Conv42': 'Deleted', 'Relu43': 'Deleted', 'Conv44': 'Deleted', 'Relu45': 'Deleted', 'Concat46': 'Deleted', 'Conv47': 'Deleted', 'Relu48': 'Deleted', 'Conv49': 'Deleted', 'Relu50': 'Deleted', 'Conv51': 'Deleted', 'Relu52': 'Deleted', 'Concat53': 'Deleted', 'Conv54': 'Deleted', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'softmaxout_1': 'Deleted'} + + node_states_quant = {'data_0': 'Exist', 'data_0_QuantizeLinear': 'Exist', 'Conv_nc_rename_0_quant': 'Exist', 'MaxPool_nc_rename_2_quant': 'Exist', 'Conv_nc_rename_3_quant': 'Deleted', 'Conv_nc_rename_5_quant': 'Deleted', 'Conv_nc_rename_7_quant': 'Deleted', 'fire2/expand1x1_2_DequantizeLinear': 'Deleted', 'fire2/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_9': 'Deleted', 'fire2/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_10_quant': 'Deleted', 'Conv_nc_rename_12_quant': 'Deleted', 'Conv_nc_rename_14_quant': 'Deleted', 'fire3/expand1x1_2_DequantizeLinear': 'Deleted', + 'fire3/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_16': 'Deleted', 'MaxPool_nc_rename_17': 'Deleted', 'pool3_1_QuantizeLinear': + 'Deleted', 'Conv_nc_rename_18_quant': 'Deleted', 'Conv_nc_rename_20_quant': 'Deleted', 'Conv_nc_rename_22_quant': 'Deleted', 'fire4/expand1x1_2_DequantizeLinear': 'Deleted', 'fire4/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_24': 'Deleted', 'fire4/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_25_quant': 'Deleted', 'Conv_nc_rename_27_quant': 'Deleted', 'Conv_nc_rename_29_quant': 'Deleted', 'fire5/expand1x1_2_DequantizeLinear': 'Deleted', 'fire5/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_31': 'Deleted', 'MaxPool_nc_rename_32': 'Deleted', 'pool5_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_33_quant': 'Deleted', 'Conv_nc_rename_35_quant': 'Deleted', 'Conv_nc_rename_37_quant': 'Deleted', 'fire6/expand1x1_2_DequantizeLinear': 'Deleted', 'fire6/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_39': 'Deleted', 'fire6/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_40_quant': 'Deleted', 'Conv_nc_rename_42_quant': 'Deleted', 'Conv_nc_rename_44_quant': 'Deleted', 'fire7/expand1x1_2_DequantizeLinear': 'Deleted', 'fire7/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_46': 'Deleted', 'fire7/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_47_quant': 'Deleted', 'Conv_nc_rename_49_quant': 'Deleted', 'Conv_nc_rename_51_quant': 'Deleted', 'fire8/expand1x1_2_DequantizeLinear': 'Deleted', 'fire8/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_53': 'Deleted', 'fire8/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_54_quant': 'Deleted', 'Conv_nc_rename_56_quant': 'Deleted', 'Conv_nc_rename_58_quant': 'Deleted', 'fire9/expand1x1_2_DequantizeLinear': 'Deleted', 'fire9/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_60': 'Deleted', 'fire9/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_61_quant': 'Deleted', 'GlobalAveragePool_nc_rename_63_quant': 'Deleted', 'pool10_1_DequantizeLinear': 'Deleted', 'Softmax_nc_rename_64': 'Deleted', 'softmaxout_1': 'Deleted'} + + # node_states = node_states_quant + node_states = node_states_fp + # print('\graph input') + # for inp in onnx_modifier.graph.input: + # print(inp.name) + onnx_modifier.remove_node_by_node_states(node_states) + print(len(onnx_modifier.graph.node)) + print(len(onnx_modifier.graph.initializer)) + print(len(onnx_modifier.initilizer_name2module.keys())) + # print(onnx_modifier.initilizer_name2module.keys()) + # for i, k in enumerate(onnx_modifier.initilizer_name2module.keys()): + # print("\nremoving", i, k) + # onnx_modifier.graph.initializer.remove(onnx_modifier.initilizer_name2module[k]) + # print("removed") + + print('\nleft initializers:') + for initializer in onnx_modifier.model_proto.graph.initializer: + print(initializer.name) + + print('\nleft nodes:') + for node in onnx_modifier.graph.node: + print(node.name) + + print('\nleft input') + for inp in onnx_modifier.graph.input: + print(inp.name) + + onnx_modifier.check_and_save_model() + remove_node_by_node_states() - + def explore_basic(): + print(type(onnx_modifier.model_proto.graph.initializer)) + print(dir(onnx_modifier.model_proto.graph.initializer)) + + print(len(onnx_modifier.model_proto.graph.node)) + print(len(onnx_modifier.model_proto.graph.initializer)) + + # for node in onnx_modifier.model_proto.graph.node: + # print(node.name) + # print(node.input) + # print() + + # for initializer in onnx_modifier.model_proto.graph.initializer: + # print(initializer.name) + # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) + + # explore_basic() - \ No newline at end of file + \ No newline at end of file diff --git a/static/view-sidebar.js b/static/view-sidebar.js index db80a6d..35216a0 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -215,6 +215,7 @@ sidebar.NodeSidebar = class { add_span(className) { const span = this._host.document.createElement('span'); + span.innerHTML = " "; // (if this doesn't work, try " ") span.className = className; this._elements.push(span); }