From 975682eafe221fc7ec4b80da0722d00056b20d90 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Wed, 8 Jun 2022 22:44:08 +0800 Subject: [PATCH] `add node` feature works for some simple node like Abs --- app.py | 6 +++--- onnx_modifier.py | 47 +++++++++++++++++++++++++++++++++++++++--- static/index.js | 21 ++++++++++++++++++- static/onnx.js | 2 +- static/view-sidebar.js | 2 ++ static/view.js | 44 ++++++++++----------------------------- 6 files changed, 81 insertions(+), 41 deletions(-) diff --git a/app.py b/app.py index 1bbedc4..04070b7 100644 --- a/app.py +++ b/app.py @@ -22,9 +22,9 @@ def modify_and_download_model(): modify_info = request.get_json() # print(modify_info) - onnx_modifier.reload() # allow for downloading for multiple times - onnx_modifier.remove_node_by_node_states(modify_info['node_states']) - onnx_modifier.modify_node_io_name(modify_info['node_renamed_io']) + onnx_modifier.reload() # allow downloading for multiple times + + onnx_modifier.modify(modify_info) onnx_modifier.check_and_save_model() return 'OK', 200 diff --git a/onnx_modifier.py b/onnx_modifier.py index c23a26f..3c5d2eb 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -94,8 +94,38 @@ class onnxModifier: for i in range(len(node.output)): if node.output[i] == src_name: node.output[i] = dst_name - # print(node.input, node.output) + + + def add_node(self, nodes_info): + for node_info in nodes_info.values(): + name = node_info['properties']['name'] + op_type = node_info['properties']['op_type'] + attributes = node_info['attributes'] + + inputs = [] + for key in node_info['inputs'].keys(): + inputs += node_info['inputs'][key] + outputs = [] + for key in node_info['outputs'].keys(): + outputs += node_info['outputs'][key] + node = onnx.helper.make_node( + op_type=op_type, + inputs=inputs, + outputs=outputs, + name=name, + **attributes + ) + # print(node) + + self.graph.node.append(node) + + + def modify(self, modify_info): + self.remove_node_by_node_states(modify_info['node_states']) + self.modify_node_io_name(modify_info['node_renamed_io']) + self.add_node(modify_info['added_node_info']) + def check_and_save_model(self, save_dir='./modified_onnx'): if not os.path.exists(save_dir): os.mkdir(save_dir) @@ -111,9 +141,10 @@ class onnxModifier: if __name__ == "__main__": - model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def remove_node_by_node_states(): @@ -167,11 +198,21 @@ if __name__ == "__main__": # for initializer in onnx_modifier.model_proto.graph.initializer: # print(initializer.name) # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) + pass # explore_basic() def test_modify_node_io_name(): node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} onnx_modifier.modify_node_io_name(node_rename_io) onnx_modifier.check_and_save_model() - test_modify_node_io_name() + # test_modify_node_io_name() + + def test_add_node(): + node_info = {'properties': {'domain': 'ai.onnx', 'op_type': 'Abs', 'name': 'custom_added_Abs0'}, 'attributes': {}, 'inputs': {'X': ['custom_input_0']}, 'outputs': {'Y': ['custom_output_1']}} + + onnx_modifier.add_node(node_info) + onnx_modifier.check_and_save_model() + + test_add_node() + \ No newline at end of file diff --git a/static/index.js b/static/index.js index 711d997..a84a834 100644 --- a/static/index.js +++ b/static/index.js @@ -214,7 +214,9 @@ host.BrowserHost = class { const downloadButton = this.document.getElementById('download-graph'); downloadButton.addEventListener('click', () => { - // console.log(this) + + console.log(this._view._graph._addedNode) + console.log(this._view._graph._renameMap) // https://healeycodes.com/talking-between-languages fetch('/download', { // Declare what type of data we're sending @@ -226,6 +228,7 @@ host.BrowserHost = class { body: JSON.stringify({ 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), + 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)) } ) @@ -640,6 +643,22 @@ host.BrowserHost = class { } return lo } + + // convert view.LightNodeInfo to Map object for easier transmission to Python backend + parseLightNodeInfo2Map(nodes_info) { + var res_map = new Map() + for (const [modelNodeName, node_info] of nodes_info) { + var node_info_map = new Map() + node_info_map.set('properties', node_info.properties) + node_info_map.set('attributes', node_info.attributes) + node_info_map.set('inputs', node_info.inputs) + node_info_map.set('outputs', node_info.outputs) + + res_map.set(modelNodeName, node_info_map) + } + + return res_map + } }; host.Dropdown = class { diff --git a/static/onnx.js b/static/onnx.js index 3743d71..022491c 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -540,7 +540,7 @@ onnx.Graph = class { const input = schema.inputs[i] var node_info_input = node_info.inputs.get(input.name) - console.log(node_info_input) + // console.log(node_info_input) var arg_list = [] if (input.list) { diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 10ee072..deffe5c 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -682,6 +682,7 @@ class NodeAttributeView { var attr_input = document.createElement("INPUT"); attr_input.setAttribute("type", "text"); + attr_input.setAttribute("size", "42"); attr_input.setAttribute("value", content ? content : 'undefined'); attr_input.addEventListener('input', (e) => { // console.log(e.target.value); @@ -855,6 +856,7 @@ sidebar.ArgumentView = class { var arg_input = document.createElement("INPUT"); arg_input.setAttribute("type", "text"); + arg_input.setAttribute("size", "42"); arg_input.setAttribute("value", name); arg_input.addEventListener('input', (e) => { // console.log(this._argument) diff --git a/static/view.js b/static/view.js index 8a4d57a..58eb634 100644 --- a/static/view.js +++ b/static/view.js @@ -875,21 +875,21 @@ view.View = class { var node = this._graphs[0].make_custom_add_node(node_info) // console.log(node) - // padding empty array for LightNodeInfo.inputs/outputs (only when initializing) - if (this.lastViewGraph._addedNode.get(modelNodeName).inputs.size == 0) { - - for (var input of node.inputs) { - var arg_len = input._arguments.length - this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len)) + for (const input of node.inputs) { + var input_list_names = [] + for (const arg of input._arguments) { + input_list_names.push(arg.name) } + this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names) + } - if (this.lastViewGraph._addedNode.get(modelNodeName).outputs.size == 0) { - - for (var output of node.outputs) { - var arg_len = output._arguments.length - this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, new Array(arg_len)) + for (const output of node.outputs) { + var output_list_names = [] + for (const arg of output._arguments) { + output_list_names.push(arg.name) } + this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, output_list_names) } } @@ -1151,31 +1151,9 @@ view.Graph = class extends grapher.Graph { // this._addedNode.push(new view.LightNodeInfo(properties)) this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties)) // console.log(this._addedNode) - - // refresh - // this.refresh_added_node() - } - // refresh_added_node() { - // this.view._graphs[0].reset_custom_added_node() - // // for (const node_info of this._addedNode.values()) { - // for (const [modelNodeName, node_info] of this._addedNode) { - // // console.log(node) - // var node = this.view._graphs[0].make_custom_add_node(node_info) - - // // padding empty array for LightNodeInfo.inputs/outputs - // for (var input of node.inputs) { - // var arg_len = input._arguments.length - // this._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len)) - // } - - // } - // // console.log(this.view._graphs[0].nodes) - // console.log(this._addedNode) - // } - changeNodeAttribute(modelNodeName, attributeName, targetValue) { if (this._addedNode.has(modelNodeName)) { this._addedNode.get(modelNodeName).attributes.set(attributeName, targetValue)