diff --git a/app.py b/app.py index ec7c09a..fad078f 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,7 @@ def open_model(): @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() - print(modify_info) + # print(modify_info) onnx_modifier.reload() # allow downloading for multiple times diff --git a/onnx_modifier.py b/onnx_modifier.py index 2a92a3e..db7735e 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -114,7 +114,9 @@ class onnxModifier: node.name = dst_name # print(node.name) # print(node) - pass + # pass + elif node_name in self.graph_output_names: + node.name = dst_name else: # print(node.input, node.output) for i in range(len(node.input)): diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 6e50c0c..9c69457 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -1119,13 +1119,16 @@ sidebar.ModelSidebar = class { this._addHeader('Inputs'); // for (const input of graph.inputs) { for (const [index, input] of graph.inputs.entries()){ - this.addArgument(input.name, input, index); + this.addArgument(input.name, input, index, 'model_input'); + // this.addArgument(input.modelNodeName, input, index, 'model_input'); } } if (Array.isArray(graph.outputs) && graph.outputs.length > 0) { this._addHeader('Outputs'); - for (const output of graph.outputs) { - this.addArgument(output.name, output); + // for (const output of graph.outputs) { + for (const [index, output] of graph.outputs.entries()){ + this.addArgument(output.name, output, index, 'model_output'); + // this.addArgument(output.modelNodeName, output, index, 'model_output'); } } } @@ -1151,9 +1154,9 @@ sidebar.ModelSidebar = class { this._elements.push(item.render()); } - addArgument(name, argument, index) { + addArgument(name, argument, index, arg_type) { // const view = new sidebar.ParameterView(this._host, argument); - const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name); + const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name); view.toggle(); const item = new sidebar.NameValueView(this._host, name, view); this._elements.push(item.render()); diff --git a/static/view.js b/static/view.js index ec39b69..362e9b3 100644 --- a/static/view.js +++ b/static/view.js @@ -963,12 +963,12 @@ view.View = class { for (var input of this._graphs[0]._inputs) { // console.log(input) // console.log(input.modelNodeName) - - if (this.lastViewGraph._renameMap.get(input.modelNodeName)) { + var input_orig_name = input.arguments[0].original_name + if (this.lastViewGraph._renameMap.get(input_orig_name)) { // for model input and output, node.modelNodeName == element.original_name - var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName) + var new_name = this.lastViewGraph._renameMap.get(input_orig_name).get(input_orig_name) // console.log(new_name) - var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, input_orig_name) input.arguments[0] = arg_with_new_name @@ -982,7 +982,7 @@ view.View = class { for (const [index, element] of node_input.arguments.entries()) { // console.log(element.name, input.modelNodeName) // if (element.name == input.modelNodeName) { - if (element.original_name == input.modelNodeName) { + if (element.original_name == input_orig_name) { // console.log(element.name) // var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) // console.log(element.original_name) @@ -1009,6 +1009,62 @@ view.View = class { } + } + } + + for (var output of this._graphs[0]._outputs) { + // console.log(output) + // console.log(this.lastViewGraph._renameMap) + // console.log(input.modelNodeName) + + // if (this.lastViewGraph._renameMap.get(output.modelNodeName)) { + var output_orig_name = output.arguments[0].original_name + if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) { + // for model input and output, node.modelNodeName == element.original_name + var new_name = this.lastViewGraph._renameMap.get('out_' + output_orig_name).get(output_orig_name) + console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, output_orig_name) + + output.arguments[0] = arg_with_new_name + + // change all the name of node input linked with model input meanwhile + for (var node of this._graphs[0]._nodes) { + // this node has some changed arguments + // console.log(node) + // console.log(node.modelNodeName) + // if (this.lastViewGraph._renameMap.get(node.modelNodeName)) { + for (var node_output of node.outputs) { + for (const [index, element] of node_output.arguments.entries()) { + // console.log(element.name, input.modelNodeName) + // if (element.name == input.modelNodeName) { + // console.log(element.original_name, output.modelNodeName) + if (element.original_name == output_orig_name) { + console.log(element.original_name) + // var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) + // console.log(element.original_name) + // console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) + + node_output.arguments[index] = arg_with_new_name + + // save the changed name into _renameMap + // as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments() + if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) { + this.lastViewGraph._renameMap.set(node.modelNodeName, new Map()); + } + + var orig_arg_name = element.original_name + this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name); + + // console.log(arg_with_new_name) + // console.log(node) + } + } + } + // } + } + + } } } @@ -1040,9 +1096,17 @@ view.Graph = class extends grapher.Graph { } createInput(input) { - const value = new view.Input(this, input); + if (this._renameMap.get(input.name)) { + var show_name = this._renameMap.get(input.name).get(input.name); + } + else { + var show_name = input.name; // input nodes should have name + } + const value = new view.Input(this, input, show_name); // value.name = (this._nodeKey++).toString(); - value.name = input.name; // input nodes should have name + + value.name = input.name; + // console.log(value.name) input.modelNodeName = input.name; this.setNode(value); return value; @@ -1050,7 +1114,13 @@ view.Graph = class extends grapher.Graph { createOutput(output) { var modelNodeName = "out_" + output.name; // in case the output has the same name with the last node - const value = new view.Output(this, output, modelNodeName); + if (this._renameMap.get(modelNodeName)) { + var show_name = this._renameMap.get(modelNodeName).get(output.name); + } + else { + var show_name = output.name; // input nodes should have name + } + const value = new view.Output(this, output, modelNodeName, show_name); // value.name = (this._nodeKey++).toString(); value.name = "out_" + output.name; // output nodes should have name output.modelNodeName = "out_" + output.name; @@ -1326,19 +1396,29 @@ view.Graph = class extends grapher.Graph { // console.log(this._addedNode) else { // for the nodes in the original model - if (!this._renameMap.get(modelNodeName)) { - this._renameMap.set(modelNodeName, new Map()); - } - if (param_type == 'model_input' || param_type == 'model_output') { + if (param_type == 'model_input') { // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments) // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0]) // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name) // console.log("changing model_input", orig_arg_name) - var orig_arg_name = modelNodeName + // var orig_arg_name = modelNodeName + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name // console.log("changing model_input", orig_arg_name) + // console.log(param_type, orig_arg_name) + } + + if (param_type == 'model_output') { + // console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName)) + // console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName).arguments[0].original_name) + modelNodeName = 'out_' + modelNodeName + console.log(modelNodeName) + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name + console.log(orig_arg_name) + // console.log("changing model_input", orig_arg_name) + // console.log(param_type, orig_arg_name) } if (param_type == 'input') { @@ -1350,6 +1430,9 @@ view.Graph = class extends grapher.Graph { // console.log(orig_arg_name) } + if (!this._renameMap.get(modelNodeName)) { + this._renameMap.set(modelNodeName, new Map()); + } this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); console.log(this._renameMap) } @@ -1523,13 +1606,14 @@ view.Node = class extends grapher.Node { view.Input = class extends grapher.Node { - constructor(context, value) { + constructor(context, value, show_name) { super(); this.context = context; this.value = value; view.Input.counter = view.Input.counter || 0; const types = value.arguments.map((argument) => argument.type || '').join('\n'); - let name = value.name || ''; + // let name = value.name || ''; + let name = show_name this.modelNodeName = value.name if (name.length > 16) { name = name.split('/').pop(); @@ -1555,12 +1639,13 @@ view.Input = class extends grapher.Node { view.Output = class extends grapher.Node { - constructor(context, value, modelNodeName) { + constructor(context, value, modelNodeName, show_name) { super(); this.context = context; this.value = value; const types = value.arguments.map((argument) => argument.type || '').join('\n'); - let name = value.name || ''; + // let name = value.name || ''; + let name = show_name; this.modelNodeName = modelNodeName if (name.length > 16) { name = name.split('/').pop();