From ffe2173b40ce11cd8fd9a96712cb0f96b96fdf1f Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Sat, 11 Jun 2022 08:18:37 +0800 Subject: [PATCH] fixed shared `argument` problem --- static/onnx.js | 33 +++++-- static/view-grapher.js | 2 + static/view-sidebar.js | 32 ++++--- static/view.js | 195 ++++++++++++++++++++++++++++++----------- 4 files changed, 189 insertions(+), 73 deletions(-) diff --git a/static/onnx.js b/static/onnx.js index bdf37b2..7294b83 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -681,7 +681,7 @@ onnx.Parameter = class { onnx.Argument = class { - constructor(name, type, initializer, annotation, description) { + constructor(name, type, initializer, annotation, description, original_name) { if (typeof name !== 'string') { throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(name) + "'."); } @@ -691,15 +691,16 @@ onnx.Argument = class { this._annotation = annotation; this._description = description || ''; - this._renamed = false; - this._new_name = null; + // this._renamed = false; + // this._new_name = null; + this.original_name = original_name || name } get name() { - if (this._renamed) { - return this._new_name; - } + // if (this._renamed) { + // return this._new_name; + // } return this._name; } @@ -1785,17 +1786,33 @@ onnx.GraphContext = class { return this._groups.get(name); } - argument(name) { + argument(name, original_name) { if (!this._arguments.has(name)) { const tensor = this.tensor(name); // console.log(name) // console.log(tensor) const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; - this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description)); + this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name)); } return this._arguments.get(name); } + mayChangedArgument(name, rename_map) { + // console.log(name, rename_map) + if (rename_map.get(name)) { + var arg_name = rename_map.get(name) + } + else { + var arg_name = name + } + // console.log(arg_name) + + // console.log(this._arguments) + + return this.argument(arg_name) + + } + createType(type) { if (!type) { return null; diff --git a/static/view-grapher.js b/static/view-grapher.js index 968ede4..45aaec8 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -15,6 +15,7 @@ grapher.Graph = class { // My code this._modelNodeName2ViewNode = new Map(); + this._modelNodeName2ModelNode = new Map(); this._modelNodeName2State = new Map(); this._namedEdges = new Map(); @@ -45,6 +46,7 @@ grapher.Graph = class { const modelNodeName = node.modelNodeName this._modelNodeName2ViewNode.set(modelNodeName, node); + this._modelNodeName2ModelNode.set(modelNodeName, node.value) // _modelNodeName2State save our modifications, and wil be initilized at the first graph construction only // otherwise the modfications will lost diff --git a/static/view-sidebar.js b/static/view-sidebar.js index acde4bf..b5946e7 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -182,16 +182,18 @@ sidebar.NodeSidebar = class { const inputs = node.inputs; if (inputs && inputs.length > 0) { this._addHeader('Inputs'); - for (const input of inputs) { - this._addInput(input.name, input); // 这里的input.name是小白格前面的名称(不是方格内的) + for (const [index, input] of inputs.entries()){ + // for (const input of inputs) { + this._addInput(input.name, input, index); // 这里的input.name是小白格前面的名称(不是方格内的) } } const outputs = node.outputs; if (outputs && outputs.length > 0) { this._addHeader('Outputs'); - for (const output of outputs) { - this._addOutput(output.name, output); + for (const [index, output] of outputs.entries()){ + // for (const output of outputs) { + this._addOutput(output.name, output, index); } } @@ -297,10 +299,10 @@ sidebar.NodeSidebar = class { this._elements.push(view.render()); } - _addInput(name, input) { + _addInput(name, input, param_idx) { // console.log(input) if (input.arguments.length > 0) { - const view = new sidebar.ParameterView(this._host, input, this._modelNodeName); + const view = new sidebar.ParameterView(this._host, input, 'input', param_idx, this._modelNodeName); view.on('export-tensor', (sender, tensor) => { this._raise('export-tensor', tensor); }); @@ -314,9 +316,9 @@ sidebar.NodeSidebar = class { } } - _addOutput(name, output) { + _addOutput(name, output, param_idx) { if (output.arguments.length > 0) { - const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, this._modelNodeName)); + const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, 'output', param_idx, this._modelNodeName)); this._outputs.push(item); this._elements.push(item.render()); } @@ -760,7 +762,7 @@ class NodeAttributeView { sidebar.ParameterView = class { - constructor(host, list, modelNodeName) { + constructor(host, list, inp_or_out, param_idx, modelNodeName) { this._host = host; this._list = list; this._modelNodeName = modelNodeName @@ -769,8 +771,8 @@ sidebar.ParameterView = class { // console.log(list) // for (const argument of list.arguments) { - for (const [index, argument] of list.arguments.entries()) { - const item = new sidebar.ArgumentView(host, argument, index, list._name, this._modelNodeName); + for (const [arg_idx, argument] of list.arguments.entries()) { + const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName); item.on('export-tensor', (sender, tensor) => { this._raise('export-tensor', tensor); }); @@ -809,9 +811,11 @@ sidebar.ParameterView = class { sidebar.ArgumentView = class { - constructor(host, argument, arg_index, parameterName, modelNodeName) { + constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) { this._host = host; this._argument = argument; + this._inp_or_out = inp_or_out + this._param_index = param_index this._arg_index = arg_index this._parameterName = parameterName this._modelNodeName = modelNodeName @@ -862,7 +866,9 @@ sidebar.ArgumentView = class { // console.log(this._argument) // console.log(this._argument.name) // console.log(e.target.value); - this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); + // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); + // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); + this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value); // console.log(this._host._view._graph._renameMap); }); this._element.appendChild(arg_input); diff --git a/static/view.js b/static/view.js index e8f44a0..1b89ebd 100644 --- a/static/view.js +++ b/static/view.js @@ -482,7 +482,7 @@ view.View = class { // // console.log(this.lastViewGraph._addedNode) // } const graph = this.activeGraph; - console.log(graph.nodes) + // console.log(graph.nodes) // console.log("_updateGraph is called"); return this._timeout(100).then(() => { @@ -901,16 +901,84 @@ view.View = class { // re-fresh node arguments in case the node inputs/outputs are changed refreshNodeArguments() { - console.log(this._renameMap) - for (var node in this._graphs[0].nodes) { + console.log(this.lastViewGraph._renameMap) + // console.log(this._graphs[0]) + console.log(this._graphs[0]._nodes) + + // for (const node_name of this.lastViewGraph._renameMap.keys()) { + // var rename_map = this.lastViewGraph._renameMap.get(node_name) + // var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) + // console.log(node) + // console.log(rename_map) + // } + + + + for (var node of this._graphs[0]._nodes) { // this node has some changed arguments - if (this._renameMap.get(node.modelNodeName)) { + // console.log(node) + // console.log(node.modelNodeName) + if (this.lastViewGraph._renameMap.get(node.modelNodeName)) { + + // check inputs + for (var input of node.inputs) { + for (const [index, element] of input.arguments.entries()) { + // console.log(element) + // console.log(element.orig_arg_name) + // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) + // input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) + // if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) { + if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { + // console.log(element.name) + var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) + console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) + + input.arguments[index] = arg_with_new_name + + // console.log(arg_with_new_name) + console.log(node) + } + } + } + // check outputs + for (var output of node.outputs) { + for (const [index, element] of output.arguments.entries()) { + // console.log(element) + // console.log(element.orig_arg_name) + // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) + // input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) + // if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) { + if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { + // console.log(element.name) + var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) + console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) + + output.arguments[index] = arg_with_new_name + + // console.log(arg_with_new_name) + console.log(node) + } + } + } + // check outputs + // for (var output of node.outputs) { + // for (const [index, element] of output.arguments.entries()) { + // // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) + // output.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) + // } + // } } - } + // console.log(this._graphs[0]._context._arguments) + + + + } }; @@ -956,22 +1024,22 @@ view.Graph = class extends grapher.Graph { return value; } - // createArgument(argument) { - // const name = argument.name; - // if (!this._arguments.has(name)) { - // this._arguments.set(name, new view.Argument(this, argument)); - // } - // return this._arguments.get(name); - // } - - // for change node inputs/outputs compatibility - createArgument(argument, name) { - const arg_name = name || argument.name; - if (!this._arguments.has(arg_name)) { - this._arguments.set(arg_name, new view.Argument(this, argument)); + createArgument(argument) { + const name = argument.name; + if (!this._arguments.has(name)) { + this._arguments.set(name, new view.Argument(this, argument)); } - return this._arguments.get(arg_name); + return this._arguments.get(name); } + + // for change node inputs/outputs compatibility + // createArgument(argument, name) { + // const arg_name = name || argument.name; + // if (!this._arguments.has(arg_name)) { + // this._arguments.set(arg_name, new view.Argument(this, argument)); + // } + // return this._arguments.get(arg_name); + // } createEdge(from, to) { const value = new view.Edge(from, to); @@ -1003,8 +1071,8 @@ view.Graph = class extends grapher.Graph { } // console.log(this._renameMap) - console.log(graph.nodes) - console.log(this._arguments) + // console.log(graph.nodes) + // console.log(this._arguments) for (var node of graph.nodes) { var viewNode = this.createNode(node); @@ -1038,20 +1106,20 @@ view.Graph = class extends grapher.Graph { // } // if this argument has been renamed - if ( - this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument._name) - // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name - ) - { - var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name) - } - else { - var arg_name = argument.name - } + // if ( + // this._renameMap.get(viewNode.modelNodeName) && + // this._renameMap.get(viewNode.modelNodeName).get(argument._name) + // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name + // ) + // { + // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name) + // } + // else { + // var arg_name = argument.name + // } + // this.createArgument(argument, arg_name).to(viewNode); - - this.createArgument(argument, arg_name).to(viewNode); + this.createArgument(argument).to(viewNode); } } } @@ -1086,20 +1154,21 @@ view.Graph = class extends grapher.Graph { // } // if this argument has been renamed - if ( - this._renameMap.get(viewNode.modelNodeName) && - this._renameMap.get(viewNode.modelNodeName).get(argument._name) - // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' - ) - { - var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); - // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) - } - else { - var arg_name = argument.name - } + // if ( + // this._renameMap.get(viewNode.modelNodeName) && + // this._renameMap.get(viewNode.modelNodeName).get(argument._name) + // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' + // ) + // { + // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); + // // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) + // } + // else { + // var arg_name = argument.name + // } + // this.createArgument(argument, arg_name).from(viewNode); - this.createArgument(argument, arg_name).from(viewNode); + this.createArgument(argument).from(viewNode); } } } @@ -1214,7 +1283,7 @@ view.Graph = class extends grapher.Graph { var modelNodeName = 'custom_added_' + op_type + node_id // console.log(op_type) - console.log(modelNodeName) + // console.log(modelNodeName) var properties = new Map() properties.set('domain', op_domain) @@ -1236,7 +1305,8 @@ view.Graph = class extends grapher.Graph { } - changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue, orig_arg_name) { + changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) { + // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue @@ -1249,13 +1319,34 @@ view.Graph = class extends grapher.Graph { } // console.log(this._addedNode) - else { // for the nodes in the original model + // else { // for the nodes in the original model + // if (!this._renameMap.get(modelNodeName)) { + // this._renameMap.set(modelNodeName, new Map()); + // } + // this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); + // // console.log(this._renameMap) + // } + + else { if (!this._renameMap.get(modelNodeName)) { this._renameMap.set(modelNodeName, new Map()); } - this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); - // console.log(this._renameMap) + // var changed_node = this._modelNodeName2ModelNode.get(modelNodeName) + // console.log(inp_or_out, param_index, arg_index) + // console.log(changed_node) + if (inp_or_out == 'input') { + // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name + // console.log(orig_arg_name) + } + if (inp_or_out == 'output') { + // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name + // console.log(orig_arg_name) + } + this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); + console.log(this._renameMap) } this.view._updateGraph()