From 7590775f7423d5b47aff8907bc9ec1d8ca60fe2c Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Sat, 11 Jun 2022 10:40:09 +0800 Subject: [PATCH] changing node inputs/outputs in sidebar in a [in-place] manner seems done --- docs/onnx_modifier_todo.md | 6 +++++- onnx_modifier.py | 1 + static/index.js | 2 +- static/onnx.js | 28 ++++++++++++++++++++-------- static/view.js | 34 +++++++++++++++++++++++++++++++--- 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/docs/onnx_modifier_todo.md b/docs/onnx_modifier_todo.md index c8cc7f5..2f79c90 100644 --- a/docs/onnx_modifier_todo.md +++ b/docs/onnx_modifier_todo.md @@ -30,4 +30,8 @@ bug(fixed): 不可连续添加某一种类型的节点(无反应) boost: 直接使用侧边栏inputs/outputs属性框完成重命名,并提供reset功能 boost: 支持处理属性的修改 boost: 支持添加更复杂的节点 -question: 在add()函数里,为什么对conv的inputs进行遍历,只能得到X,而得不到W和B? \ No newline at end of file +question: 在add()函数里,为什么对conv的inputs进行遍历,只能得到X,而得不到W和B? + + +# 其他 +在修改节点输入输出时,建议修改方法是:把某一节点的输入,更改为另一节点的输出;而不是把某一节点的输出,改为另一节点的输入。 \ No newline at end of file diff --git a/onnx_modifier.py b/onnx_modifier.py index 3c5d2eb..4fa6856 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -122,6 +122,7 @@ class onnxModifier: def modify(self, modify_info): + print(modify_info['node_renamed_io']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.add_node(modify_info['added_node_info']) diff --git a/static/index.js b/static/index.js index b4830de..79c52a3 100644 --- a/static/index.js +++ b/static/index.js @@ -244,7 +244,7 @@ host.BrowserHost = class { } else { // swal("Error happens!", "You are kindly to create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); - swal("Error happens!", "You can check the log or kindly create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); + swal("Error happens!", "You can check the log and kindly create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); // alert('Error happens, you can find it out or create an issue on https://github.com/ZhangGe6/onnx-modifier') } }); diff --git a/static/onnx.js b/static/onnx.js index 7294b83..4d61676 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -693,7 +693,9 @@ onnx.Argument = class { // this._renamed = false; // this._new_name = null; + console.log(original_name) this.original_name = original_name || name + console.log(this.original_name) } @@ -1787,14 +1789,24 @@ onnx.GraphContext = class { } argument(name, original_name) { - if (!this._arguments.has(name)) { - const tensor = this.tensor(name); - // console.log(name) - // console.log(tensor) - const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; - this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name)); - } - return this._arguments.get(name); + const tensor = this.tensor(name); + // console.log(name) + // console.log(tensor) + const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; + return new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name); + + + // // if (!this._arguments.has(name)) { + // if ((!this._arguments.has(name)) || + // (this._arguments.has(name) && !this._arguments.get(name).original_name == original_name) + // ) { + // const tensor = this.tensor(name); + // // console.log(name) + // // console.log(tensor) + // const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; + // this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name)); + // } + // return this._arguments.get(name); } mayChangedArgument(name, rename_map) { diff --git a/static/view.js b/static/view.js index 1b89ebd..c8e8a28 100644 --- a/static/view.js +++ b/static/view.js @@ -901,9 +901,9 @@ view.View = class { // re-fresh node arguments in case the node inputs/outputs are changed refreshNodeArguments() { - console.log(this.lastViewGraph._renameMap) + // console.log(this.lastViewGraph._renameMap) // console.log(this._graphs[0]) - console.log(this._graphs[0]._nodes) + // console.log(this._graphs[0]._nodes) // for (const node_name of this.lastViewGraph._renameMap.keys()) { // var rename_map = this.lastViewGraph._renameMap.get(node_name) @@ -931,12 +931,13 @@ view.View = class { if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { // console.log(element.name) var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) + console.log(element.original_name) console.log(new_name) var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) input.arguments[index] = arg_with_new_name - // console.log(arg_with_new_name) + console.log(arg_with_new_name) console.log(node) } } @@ -1263,11 +1264,38 @@ view.Graph = class extends grapher.Graph { } resetGraph() { + // reset node states for (const nodeId of this.nodes.keys()) { const node = this.node(nodeId); this._modelNodeName2State.set(node.label.modelNodeName, 'Exist') } + + console.log(this._renameMap) + // reset node inputs/outputs + for (const changed_node_name of this._renameMap.keys()) { + var node = this._modelNodeName2ModelNode.get(changed_node_name) + console.log(node) + //reset inputs + for (var input of node.inputs) { + for (var i = 0; i < input.arguments.length; ++i) { + console.log(input.arguments[i].original_name) + if (this._renameMap.get(node.modelNodeName).get(input.arguments[i].original_name)) { + input.arguments[i] = this.view._graphs[0]._context.argument(input.arguments[i].original_name) + } + } + } + + // reset outputs + // for (var output of node.outputs) { + // for (var i = 0; i < output.arguments.length; ++i) { + // if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) { + // output.arguments[index] = this._graphs[0]._context.argument(output.arguments[i].original_name) + // } + // } + // } + } this._renameMap = new Map(); + } recordRenameInfo(modelNodeName, src_name, dst_name) {