From 323c85fb1f700c52844b00174863dad71f0d20e7 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Sat, 30 Apr 2022 18:53:46 +0800 Subject: [PATCH] modify_node_io_name() basically done --- app.py | 7 ++++--- onnx_modifier.py | 43 ++++++++++++++++++++++++++++++++---------- static/index.js | 31 ++++++++++++++++++++++++++---- static/view-sidebar.js | 2 +- static/view.js | 6 ++++-- 5 files changed, 69 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index 618c8d7..eb07487 100644 --- a/app.py +++ b/app.py @@ -20,11 +20,12 @@ def return_file(): @app.route('/download', methods=['POST']) def modify_and_download_model(): - node_states = json.loads(request.get_json()) + modify_info = request.get_json() + # print(modify_info) - # print(node_states) onnx_modifier.reload() # allow for downloading for multiple times - onnx_modifier.remove_node_by_node_states(node_states) + onnx_modifier.remove_node_by_node_states(modify_info['node_states']) + onnx_modifier.modify_node_io_name(modify_info['node_renamed_io']) onnx_modifier.check_and_save_model() diff --git a/onnx_modifier.py b/onnx_modifier.py index 7eee1ce..6fe4f6d 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -10,6 +10,7 @@ class onnxModifier: def __init__(self, model_name, model_proto): self.model_name = model_name self.model_proto_backup = model_proto + self.reload() @classmethod def from_model_path(cls, model_path): @@ -78,12 +79,29 @@ class onnxModifier: for init_name in self.initilizer_name2module.keys(): if not init_name in left_node_inputs: self.initializer.remove(self.initilizer_name2module[init_name]) + + def modify_node_io_name(self, node_renamed_io): + # print(node_renamed_io) + for node_name in node_renamed_io.keys(): + renamed_ios = node_renamed_io[node_name] + for src_name, dst_name in renamed_ios.items(): + # print(src_name, dst_name) + node = self.node_name2module[node_name] + # print(node.input, node.output) + for i in range(len(node.input)): + if node.input[i] == src_name: + node.input[i] = dst_name + for i in range(len(node.output)): + if node.output[i] == src_name: + node.output[i] = dst_name + # print(node.input, node.output) def check_and_save_model(self, save_dir='./res_onnx'): save_path = os.path.join(save_dir, 'modified_' + self.model_name) - # onnx.checker.check_model(self.model_proto) + onnx.checker.check_model(self.model_proto) onnx.save(self.model_proto, save_path) + def inference(self): # model_proto_bytes = onnx._serialize(model_proto_from_stream) # inference_session = rt.InferenceSession(model_proto_bytes) @@ -91,9 +109,9 @@ class onnxModifier: if __name__ == "__main__": - # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def remove_node_by_node_states(): @@ -134,7 +152,7 @@ if __name__ == "__main__": print(inp.name) onnx_modifier.check_and_save_model() - remove_node_by_node_states() + # remove_node_by_node_states() def explore_basic(): print(type(onnx_modifier.model_proto.graph.initializer)) @@ -143,15 +161,20 @@ if __name__ == "__main__": print(len(onnx_modifier.model_proto.graph.node)) print(len(onnx_modifier.model_proto.graph.initializer)) - # for node in onnx_modifier.model_proto.graph.node: - # print(node.name) - # print(node.input) - # print() + for node in onnx_modifier.model_proto.graph.node: + print(node.name) + print(node.input) + print() # for initializer in onnx_modifier.model_proto.graph.initializer: # print(initializer.name) # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) - + pass # explore_basic() - \ No newline at end of file + def test_modify_node_io_name(): + node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} + onnx_modifier.modify_node_io_name(node_rename_io) + onnx_modifier.check_and_save_model() + test_modify_node_io_name() + \ No newline at end of file diff --git a/static/index.js b/static/index.js index 48ce9b8..807701b 100644 --- a/static/index.js +++ b/static/index.js @@ -214,7 +214,8 @@ host.BrowserHost = class { const downloadButton = this.document.getElementById('download-graph'); downloadButton.addEventListener('click', () => { - // console.log(this._host._view._graph._modelNodeName2State) + console.log(this._view._graph._modelNodeName2State) + console.log(this._view._graph._renameMap) // https://healeycodes.com/talking-between-languages fetch('/download', { // Declare what type of data we're sending @@ -224,8 +225,13 @@ host.BrowserHost = class { // Specify the method method: 'POST', // https://blog.csdn.net/Crazy_SunShine/article/details/80624366 - body: JSON.stringify( - this._mapToJson(this._view._graph._modelNodeName2State) + body: JSON.stringify({ + // 'node_states' : this._mapToJson(this._view._graph._modelNodeName2State), + // 'node_renamed_io' : this._twoLevelMapToJson(this._view._graph._renameMap), + 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), + 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), + } + ) }).then(function (response) { return response.text(); @@ -605,16 +611,33 @@ host.BrowserHost = class { _strMapToObj(strMap){ let obj = Object.create(null); - for (let[k, v] of strMap) { + for (let [k, v] of strMap) { obj[k] = v; } return obj; } + // {key1:val1, key2:val2, ...} => Json _mapToJson(map) { return JSON.stringify(this._strMapToObj(map)); } + // https://www.xul.fr/javascript/map-and-object.php + mapToObjectRec(m) { + let lo = {} + for(let[k,v] of m) { + if(v instanceof Map) { + lo[k] = this.mapToObjectRec(v) + } + else { + lo[k] = v + } + } + return lo + } + + + }; host.Dropdown = class { diff --git a/static/view-sidebar.js b/static/view-sidebar.js index fcea5ec..8ec27aa 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -247,7 +247,7 @@ sidebar.NodeSidebar = class { newNameElement.setAttribute('value', this._host._view._graph._renameMap.get(this._modelNodeName).get(argument.name)); } newNameElement.addEventListener('input', (e) => { - // console.log(e.target.value); + console.log(e.target.value); this._host._view._graph.recordRenameInfo(this._modelNodeName, argument.name, e.target.value); // console.log(this._host._view._graph._renameMap); diff --git a/static/view.js b/static/view.js index bae2c04..7550583 100644 --- a/static/view.js +++ b/static/view.js @@ -958,7 +958,8 @@ view.Graph = class extends grapher.Graph { // if this argument has been renamed if ( this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument.name) + this._renameMap.get(viewNode.modelNodeName).get(argument.name) && + !this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // in case user clear the input name ) { // argument.name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); @@ -992,7 +993,8 @@ view.Graph = class extends grapher.Graph { // if this argument has been renamed if ( this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument.name) + this._renameMap.get(viewNode.modelNodeName).get(argument.name) && + !this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' ) { // console.log(argument.name)