From f9aa8840c391d9af433ce7dcad67628fddbda6ea Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Fri, 10 Jun 2022 16:48:33 +0800 Subject: [PATCH] this is a backup before making some challenging changes for refresh node arguments --- docs/onnx_modifier_todo.md | 12 ++- static/onnx.js | 4 +- static/view-sidebar.js | 2 +- static/view.js | 152 ++++++++++++++++++++++++++++--------- 4 files changed, 131 insertions(+), 39 deletions(-) diff --git a/docs/onnx_modifier_todo.md b/docs/onnx_modifier_todo.md index a22afb1..c8cc7f5 100644 --- a/docs/onnx_modifier_todo.md +++ b/docs/onnx_modifier_todo.md @@ -13,7 +13,6 @@ add node (add preprocess nodes): https://zhuanlan.zhihu.com/p/394395167 topk: https://github.com/onnx/onnx/issues/2921 - # done remove layer: https://github.com/onnx/onnx/issues/2638 @@ -22,4 +21,13 @@ remove layer: https://github.com/onnx/onnx/issues/2638 # 或许可以帮助 -http://yyixx.com/docs/algo/onnx/ \ No newline at end of file +http://yyixx.com/docs/algo/onnx/ + + +# 待做的 + +bug(fixed): 不可连续添加某一种类型的节点(无反应) +boost: 直接使用侧边栏inputs/outputs属性框完成重命名,并提供reset功能 +boost: 支持处理属性的修改 +boost: 支持添加更复杂的节点 +question: 在add()函数里,为什么对conv的inputs进行遍历,只能得到X,而得不到W和B? \ No newline at end of file diff --git a/static/onnx.js b/static/onnx.js index 022491c..bdf37b2 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -521,7 +521,7 @@ onnx.Graph = class { return 'graph(' + this.name + ')'; } - make_custom_add_node(node_info) { + make_custom_added_node(node_info) { // type of node_info == LightNodeInfo const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain')); // console.log(schema) @@ -865,7 +865,7 @@ onnx.Attribute = class { // console.log(attribute) this._value = attribute.value; this._type = attribute.type; - // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_add_node. This is unsafe + // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is unsafe // throw new onnx.Error("Unknown attribute type '" + attribute.type + "'."); } // console.log(attribute.type) diff --git a/static/view-sidebar.js b/static/view-sidebar.js index deffe5c..acde4bf 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -862,7 +862,7 @@ sidebar.ArgumentView = class { // console.log(this._argument) // console.log(this._argument.name) // console.log(e.target.value); - this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); + this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); // console.log(this._host._view._graph._renameMap); }); this._element.appendChild(arg_input); diff --git a/static/view.js b/static/view.js index 38a64df..e8f44a0 100644 --- a/static/view.js +++ b/static/view.js @@ -461,6 +461,7 @@ view.View = class { var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; if (active_graph && this.lastViewGraph) { this.refreshAddedNode() + this.refreshNodeArguments() } return active_graph @@ -481,8 +482,7 @@ view.View = class { // // console.log(this.lastViewGraph._addedNode) // } const graph = this.activeGraph; - // console.log(graph.nodes) - + console.log(graph.nodes) // console.log("_updateGraph is called"); return this._timeout(100).then(() => { @@ -582,6 +582,7 @@ view.View = class { viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; viewGraph._renameMap = this.lastViewGraph._renameMap; viewGraph._addedNode = this.lastViewGraph._addedNode; + viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) } @@ -867,12 +868,13 @@ view.View = class { } + // re-generate the added node according to _addedNode refreshAddedNode() { this._graphs[0].reset_custom_added_node() // for (const node_info of this._addedNode.values()) { for (const [modelNodeName, node_info] of this.lastViewGraph._addedNode) { // console.log(node_info) - var node = this._graphs[0].make_custom_add_node(node_info) + var node = this._graphs[0].make_custom_added_node(node_info) // console.log(node) for (const input of node.inputs) { @@ -897,6 +899,20 @@ view.View = class { // console.log(this.lastViewGraph._addedNode) } + // re-fresh node arguments in case the node inputs/outputs are changed + refreshNodeArguments() { + console.log(this._renameMap) + for (var node in this._graphs[0].nodes) { + // this node has some changed arguments + if (this._renameMap.get(node.modelNodeName)) { + + + } + + } + + } + }; view.Graph = class extends grapher.Graph { @@ -915,6 +931,7 @@ view.Graph = class extends grapher.Graph { createNode(node) { var node_id = (this._nodeKey++).toString(); // in case input (onnx) node has no name var modelNodeName = node.name ? node.name : node.type.name + node_id + node.modelNodeName = modelNodeName // this will take in-place effect for the onnx.Node in onnx.Graph, which can make it more convenient if we want to find a node in onnx.Graph later const value = new view.Node(this, node, modelNodeName); value.name = node_id; @@ -939,12 +956,21 @@ view.Graph = class extends grapher.Graph { return value; } - createArgument(argument) { - const name = argument.name; - if (!this._arguments.has(name)) { - this._arguments.set(name, new view.Argument(this, argument)); + // createArgument(argument) { + // const name = argument.name; + // if (!this._arguments.has(name)) { + // this._arguments.set(name, new view.Argument(this, argument)); + // } + // return this._arguments.get(name); + // } + + // for change node inputs/outputs compatibility + createArgument(argument, name) { + const arg_name = name || argument.name; + if (!this._arguments.has(arg_name)) { + this._arguments.set(arg_name, new view.Argument(this, argument)); } - return this._arguments.get(name); + return this._arguments.get(arg_name); } createEdge(from, to) { @@ -976,26 +1002,56 @@ view.Graph = class extends grapher.Graph { } } - for (const node of graph.nodes) { - const viewNode = this.createNode(node); + // console.log(this._renameMap) + console.log(graph.nodes) + console.log(this._arguments) + for (var node of graph.nodes) { + var viewNode = this.createNode(node); - const inputs = node.inputs; - for (const input of inputs) { - for (const argument of input.arguments) { + var inputs = node.inputs; + for (var input of inputs) { + for (var argument of input.arguments) { if (argument.name != '' && !argument.initializer) { + + // if (viewNode.modelNodeName == "Conv3") { + // console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) + // console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time? + // console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time? + // } + + // if this argument has been renamed + // if ( + // this._renameMap.get(viewNode.modelNodeName) && + // this._renameMap.get(viewNode.modelNodeName).get(argument._name) && + // !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name + // ) + // { + // argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); + // argument._renamed = true; + // } + // else { argument._renamed = false; } + + // if (viewNode.modelNodeName == "Conv3") { + // console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) + // console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time? + // console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time? + // } + // if this argument has been renamed if ( this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument.name) && - !this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // in case user clear the input name + this._renameMap.get(viewNode.modelNodeName).get(argument._name) + // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name ) { - argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); - argument._renamed = true; + var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name) } - else { argument._renamed = false; } + else { + var arg_name = argument.name + } + - this.createArgument(argument).to(viewNode); + this.createArgument(argument, arg_name).to(viewNode); } } } @@ -1006,25 +1062,44 @@ view.Graph = class extends grapher.Graph { outputs = chainOutputs; } } - for (const output of outputs) { - for (const argument of output.arguments) { + for (var output of outputs) { + for (var argument of output.arguments) { if (!argument) { throw new view.Error("Invalid null argument in '" + this.model.identifier + "'."); } if (argument.name != '') { + // // if this argument has been renamed + // if ( + // this._renameMap.get(viewNode.modelNodeName) && + // this._renameMap.get(viewNode.modelNodeName).get(argument._name) && + // !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' + // ) + // { + // argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); + // argument._renamed = true; + // // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) + // } + // else { argument._renamed = false; } + + // if (viewNode.modelNodeName == "MaxPool2") { + // console.log("output", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) + // } + // if this argument has been renamed if ( this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument.name) && - !this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' + this._renameMap.get(viewNode.modelNodeName).get(argument._name) + // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' ) { - argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); - argument._renamed = true; + var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); + // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) + } + else { + var arg_name = argument.name } - else { argument._renamed = false; } - this.createArgument(argument).from(viewNode); + this.createArgument(argument, arg_name).from(viewNode); } } } @@ -1134,15 +1209,12 @@ view.Graph = class extends grapher.Graph { } add_node(op_domain, op_type) { - // node_name: the added op name - // parent_node_name: parent modelNodeName - // console.log(node_name) var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name var modelNodeName = 'custom_added_' + op_type + node_id // console.log(op_type) - // console.log(modelNodeName) + console.log(modelNodeName) var properties = new Map() properties.set('domain', op_domain) @@ -1164,8 +1236,8 @@ view.Graph = class extends grapher.Graph { } - changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { - if (this._addedNode.has(modelNodeName)) { + changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue, orig_arg_name) { + if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue } @@ -1173,9 +1245,21 @@ view.Graph = class extends grapher.Graph { if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue } - this.view._updateGraph() // otherwise the changes can not be updated without manully update graph + // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph } // console.log(this._addedNode) + + else { // for the nodes in the original model + if (!this._renameMap.get(modelNodeName)) { + this._renameMap.set(modelNodeName, new Map()); + } + this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); + // console.log(this._renameMap) + + } + + this.view._updateGraph() + }