From 38649877ebf8f755f0e1493c4ae0c3e2da3a8714 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Sat, 13 Aug 2022 12:18:04 +0800 Subject: [PATCH] implement `adding new model output` feature, which is mentioned in https://github.com/ZhangGe6/onnx-modifier/issues/7, https://github.com/ZhangGe6/onnx-modifier/issues/8, and https://github.com/ZhangGe6/onnx-modifier/issues/13 --- onnx_modifier.py | 66 +++++++++++++++++++----------------------- static/index.js | 38 +++++++++++++++++++----- static/onnx.js | 13 ++++++++- static/view-grapher.js | 1 + static/view-sidebar.js | 11 +++++-- static/view.js | 36 ++++++++++++++++------- 6 files changed, 107 insertions(+), 58 deletions(-) diff --git a/onnx_modifier.py b/onnx_modifier.py index 2ffff32..bf42a59 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -109,11 +109,12 @@ class onnxModifier: # self.initializer.remove(self.initializer_name2module[init_name]) def modify_node_io_name(self, node_renamed_io): - # print(node_renamed_io) for node_name in node_renamed_io.keys(): + if node_name not in self.node_name2module.keys(): + # custom added nodes or custom added model outputs + continue renamed_ios = node_renamed_io[node_name] for src_name, dst_name in renamed_ios.items(): - # print(src_name, dst_name) node = self.node_name2module[node_name] if node_name in self.graph_input_names: node.name = dst_name @@ -149,16 +150,28 @@ class onnxModifier: self.graph.node.append(node) - + def add_outputs(self, added_outputs, node_states): + # https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445 + added_output_names = added_outputs.values() + # filter out the deleted custom-added outputs + value_info_protos = [] + shape_info = onnx.shape_inference.infer_shapes(self.model_proto) + for value_info in shape_info.graph.value_info: + if value_info.name in added_output_names: + value_info_protos.append(value_info) + self.graph.output.extend(value_info_protos) + def modify(self, modify_info): # print(modify_info['node_states']) # print(modify_info['node_renamed_io']) # print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) + # print(modify_info['added_outputs']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_attr(modify_info['node_changed_attr']) self.add_node(modify_info['added_node_info'], modify_info['node_states']) + self.add_outputs(modify_info['added_outputs'], modify_info['node_states']) def check_and_save_model(self, save_dir='./modified_onnx'): @@ -191,15 +204,15 @@ class onnxModifier: # This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506 out = inference_session.run(None, {input_name: x})[0] - # print(out) - + print(out.shape) if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" - # model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why? # model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) @@ -259,7 +272,6 @@ if __name__ == "__main__": onnx_modifier.check_and_save_model() # remove_node_by_node_states() - def test_modify_node_io_name(): node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}} @@ -285,34 +297,14 @@ if __name__ == "__main__": onnx_modifier.check_and_save_model() # test_change_node_attr() - - def debug_remove_node_by_node_states(): - # print(len(onnx_modifier.graph.node)) - # print(len(onnx_modifier.graph.initializer)) - - # print(onnx_modifier.node_name2module.keys()) - # print(onnx_modifier.graph.node) - # for node in onnx_modifier.graph.node: - # print(node.name) - # print(node.input) - # print(node.output) - - # print('\noriginal input') - # for inp in onnx_modifier.graph.input: - # print(inp.name) - - - node_states = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist', 'Relu11': 'Exist', 'Conv12': 'Exist', 'Relu13': 'Exist', 'Conv14': 'Exist', 'Relu15': 'Exist', 'Concat16': 'Exist', 'MaxPool17': 'Exist', 'Conv18': 'Exist', 'Relu19': 'Exist', 'Conv20': 'Exist', 'Relu21': 'Exist', 'Conv22': 'Exist', 'Relu23': 'Exist', 'Concat24': 'Exist', 'Conv25': 'Exist', 'Relu26': 'Exist', 'Conv27': 'Exist', 'Relu28': 'Exist', 'Conv29': 'Exist', 'Relu30': 'Exist', 'Concat31': 'Exist', 'MaxPool32': 'Exist', 'Conv33': 'Exist', 'Relu34': 'Exist', 'Conv35': 'Exist', 'Relu36': 'Exist', 'Conv37': 'Exist', 'Relu38': 'Exist', 'Concat39': 'Exist', 'Conv40': 'Exist', 'Relu41': 'Exist', 'Conv42': 'Exist', 'Relu43': 'Exist', 'Conv44': 'Exist', 'Relu45': 'Exist', 'Concat46': 'Exist', 'Conv47': 'Exist', 'Relu48': 'Exist', 'Conv49': 'Exist', 'Relu50': 'Exist', 'Conv51': 'Exist', 'Relu52': 'Exist', 'Concat53': 'Exist', 'Conv54': 'Exist', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'out_softmaxout_1': 'Deleted'} - # print('\graph input') - # for inp in onnx_modifier.graph.input: - # print(inp.name) - onnx_modifier.remove_node_by_node_states(node_states) - - print('\nleft input') - for inp in onnx_modifier.graph.input: - print(inp.name) - - onnx_modifier.check_and_save_model() - debug_remove_node_by_node_states() + def test_inference(): + onnx_modifier.inference() + test_inference() + def test_add_output(): + # print(onnx_modifier.graph.output) + onnx_modifier.add_outputs(['fire2/squeeze1x1_1']) + # print(onnx_modifier.graph.output) + onnx_modifier.check_and_save_model() + # test_add_output() \ No newline at end of file diff --git a/static/index.js b/static/index.js index 98be550..3031e19 100644 --- a/static/index.js +++ b/static/index.js @@ -229,10 +229,10 @@ host.BrowserHost = class { 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), - 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)) - } - - ) + 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)), + 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, + this._view._graph._renameMap, this._view._graph._modelNodeName2State)) + }) }).then(function (response) { return response.text(); }).then(function (text) { @@ -264,9 +264,6 @@ host.BrowserHost = class { this._view._updateGraph(); }) - - - this.document.getElementById('version').innerText = this.version; if (this._meta.file) { @@ -677,6 +674,33 @@ host.BrowserHost = class { } return lo } + + // this function does 2 things: + // 1. rename the addedOutputs with their new names using renameMap. Because addedOutputs are stored in lists, + // it may be not easy to rename them while editing. (Of course there may be a better way to do this) + // 2. filter out the custom output which is added, but deleted later + process_added_outputs(addedOutputs, renameMap, modelNodeName2State) { + var processed = [] + for (let i = 0; i < addedOutputs.length; ++i) { + if (modelNodeName2State.get("out_" + addedOutputs[i]) == "Exist") { + processed.push(addedOutputs[i]); + } + } + for (let i = 0; i < processed.length; ++i) { + if (renameMap.get("out_" + processed[i])) { + processed[i] = renameMap.get("out_" + processed[i]).get(processed[i]); + } + } + return processed; + } + + // https://stackoverflow.com/a/4215753/10096987 + arrayToObject(arr) { + var rv = {}; + for (var i = 0; i < arr.length; ++i) + if (arr[i] !== undefined) rv[i] = arr[i]; + return rv; + } // convert view.LightNodeInfo to Map object for easier transmission to Python backend parseLightNodeInfo2Map(nodes_info) { diff --git a/static/onnx.js b/static/onnx.js index 2ae0842..89eb18e 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -438,6 +438,7 @@ onnx.Graph = class { this._custom_add_node_io_idx = 0 this._custom_added_node = [] + this._custom_added_outputs = [] // model parameter assignment here! // console.log(graph) @@ -504,7 +505,8 @@ onnx.Graph = class { } get outputs() { - return this._outputs; + // return this._outputs; + return this._outputs.concat(this._custom_added_outputs); } get nodes() { @@ -632,6 +634,15 @@ onnx.Graph = class { return custom_add_node; } + reset_custom_added_outputs() { + this._custom_added_outputs = []; + } + + add_output(name) { + const argument = this._context.argument(name); + this._custom_added_outputs.push(new onnx.Parameter(name, [ argument ])); + } + }; onnx.Parameter = class { diff --git a/static/view-grapher.js b/static/view-grapher.js index 5e8f253..169a414 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -24,6 +24,7 @@ grapher.Graph = class { this._changedAttributes = new Map(); this._addedNode = new Map(); + this._addedOutputs = []; } get options() { diff --git a/static/view-sidebar.js b/static/view-sidebar.js index e89a1c5..71f7ffc 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -209,6 +209,9 @@ sidebar.NodeSidebar = class { this._addButton('Recover Node'); this.add_separator(this._elements, 'sidebar-view-separator') this._addButton('Enter'); + + this._addHeader('Output adding helper'); + this._addButton('Add Output'); // deprecated // this.add_separator(this._elements, 'sidebar-view-separator'); @@ -272,8 +275,6 @@ sidebar.NodeSidebar = class { } } } - - } render() { @@ -356,7 +357,11 @@ sidebar.NodeSidebar = class { this._host._view._updateGraph() }); } - + if (title === 'Add Output') { + buttonElement.addEventListener('click', () => { + this._host._view._graph.add_output(this._modelNodeName) + }); + } } // deprecated diff --git a/static/view.js b/static/view.js index aff8a92..1cdc2d2 100644 --- a/static/view.js +++ b/static/view.js @@ -464,7 +464,6 @@ view.View = class { this.refreshModelInputOutput() this.refreshNodeArguments() this.refreshNodeAttributes() - } return active_graph @@ -580,7 +579,8 @@ view.View = class { viewGraph._renameMap = this.lastViewGraph._renameMap; viewGraph._changedAttributes = this.lastViewGraph._changedAttributes; viewGraph._addedNode = this.lastViewGraph._addedNode; - viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey + viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey; + viewGraph._addedOutputs = this.lastViewGraph._addedOutputs; // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) } @@ -866,7 +866,7 @@ view.View = class { } - // re-generate the added node according to _addedNode + // re-generate the added node according to _addedNode according to the latest _addedNode refreshAddedNode() { this._graphs[0].reset_custom_added_node() // for (const node_info of this._addedNode.values()) { @@ -880,8 +880,7 @@ view.View = class { for (const arg of input._arguments) { input_list_names.push(arg.name) } - this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names) - + this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names) } for (const output of node.outputs) { @@ -896,7 +895,7 @@ view.View = class { } // re-fresh node arguments in case the node inputs/outputs are changed - refreshNodeArguments() { + refreshNodeArguments() { for (var node of this._graphs[0]._nodes) { if (this.lastViewGraph._renameMap.get(node.modelNodeName)) { @@ -979,7 +978,13 @@ view.View = class { } } - for (var output of this._graphs[0]._outputs) { + // create and add new output to graph + this._graphs[0].reset_custom_added_outputs(); + for (var output_name of this.lastViewGraph._addedOutputs) { + this._graphs[0].add_output(output_name); + } + // console.log(this._graphs[0].outputs) + for (var output of this._graphs[0].outputs) { var output_orig_name = output.arguments[0].original_name if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) { // for model input and output, node.modelNodeName == element.original_name @@ -1011,6 +1016,7 @@ view.View = class { } } } + // console.log(this.lastViewGraph._renameMap) } } } @@ -1190,7 +1196,6 @@ view.Graph = class extends grapher.Graph { } } - for (const output of graph.outputs) { const viewOutput = this.createOutput(output); for (const argument of output.arguments) { @@ -1238,6 +1243,17 @@ view.Graph = class extends grapher.Graph { } } + add_output(node_name) { + var model_node = this._modelNodeName2ModelNode.get(node_name); + for (var output of model_node.outputs) { + for (var argument of output.arguments) { + this._addedOutputs.push(argument.name); + } + } + // console.log(this._addedOutputs); + this.view._updateGraph(); + } + resetGraph() { // reset node states for (const nodeId of this.nodes.keys()) { @@ -1277,14 +1293,14 @@ view.Graph = class extends grapher.Graph { } } - - } this._renameMap = new Map(); // clear custom added nodes this._addedNode = new Map() this.view._graphs[0].reset_custom_added_node() + this._addedOutputs = [] + this.view._graphs[0].reset_custom_added_outputs() } recordRenameInfo(modelNodeName, src_name, dst_name) {