diff --git a/static/onnx.js b/static/onnx.js index 42a406b..3743d71 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -533,35 +533,86 @@ onnx.Graph = class { // var min_input = schema.max_input var max_custom_add_input_num = Math.min(schema.max_input, 5) // set at most 5 custom_add inputs var max_custom_add_output_num = Math.min(schema.max_output, 5) // set at most 5 custom_add outputs - + + // console.log(node_info) var inputs = [] for (let i = 0; i < schema.inputs.length; ++i) { const input = schema.inputs[i] + + var node_info_input = node_info.inputs.get(input.name) + console.log(node_info_input) + var arg_list = [] if (input.list) { for (let j = 0; j < max_custom_add_input_num; ++j) { - var arg_name = 'custom_input_' + (this.custom_add_node_io_idx++).toString() + if (node_info_input && node_info_input[j]) { + var arg_name = node_info_input[j] + } + else { + var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() + } arg_list.push(this._context.argument(arg_name)) } } else { - var arg_name = 'custom_input_' + (this.custom_add_node_io_idx++).toString() + // if (node_info_input) { + // if (!node_info_input[0]) { + // console.log('got empty') + // } + // else { + // console.log(node_info_input[0]) + // } + // } + + if (node_info_input && node_info_input[0]) { + var arg_name = node_info_input[0] + } + else { + var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() + } arg_list = [this._context.argument(arg_name)] } + + // var arg_list = [] + // if (input.list) { + // for (let j = 0; j < max_custom_add_input_num; ++j) { + // var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() + // arg_list.push(this._context.argument(arg_name)) + // } + // } + // else { + // var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() + // arg_list = [this._context.argument(arg_name)] + // } + inputs.push(new onnx.Parameter(input.name, arg_list)); } + var outputs = [] for (let i = 0; i < schema.outputs.length; ++i) { const output = schema.outputs[i] + var node_info_output = node_info.outputs.get(output.name) + var arg_list = [] if (output.list) { for (let j = 0; j < max_custom_add_output_num; ++j) { - var arg_name = 'custom_output_' + (this.custom_add_node_io_idx++).toString() + if (node_info_output && node_info_output[j]) { + var arg_name = node_info_output[j] + } + else { + var arg_name = 'custom_output_' + (this._custom_add_node_io_idx++).toString() + } arg_list.push(this._context.argument(arg_name)) } } else { - var arg_name = 'custom_output_' + (this.custom_add_node_io_idx++).toString() + if (node_info_output && node_info_output[0]) { + var arg_name = node_info_output[0] + } + else { + var arg_name = 'custom_output_' + (this._custom_add_node_io_idx++).toString() + } + arg_list = [this._context.argument(arg_name)] } outputs.push(new onnx.Parameter(output.name, arg_list)); diff --git a/static/view-sidebar.js b/static/view-sidebar.js index eafb649..10ee072 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -682,7 +682,7 @@ class NodeAttributeView { var attr_input = document.createElement("INPUT"); attr_input.setAttribute("type", "text"); - attr_input.setAttribute("value", content ? content : ' '); + attr_input.setAttribute("value", content ? content : 'undefined'); attr_input.addEventListener('input', (e) => { // console.log(e.target.value); this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, e.target.value); @@ -860,7 +860,7 @@ sidebar.ArgumentView = class { // console.log(this._argument) // console.log(this._argument.name) // console.log(e.target.value); - this._host._view._graph.changeNodeInput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); + this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); // console.log(this._host._view._graph._renameMap); }); this._element.appendChild(arg_input); diff --git a/static/view.js b/static/view.js index 5861ea0..8a4d57a 100644 --- a/static/view.js +++ b/static/view.js @@ -875,14 +875,24 @@ view.View = class { var node = this._graphs[0].make_custom_add_node(node_info) // console.log(node) - // padding empty array for LightNodeInfo.inputs/outputs - for (var input of node.inputs) { - var arg_len = input._arguments.length - this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len)) + // padding empty array for LightNodeInfo.inputs/outputs (only when initializing) + if (this.lastViewGraph._addedNode.get(modelNodeName).inputs.size == 0) { + + for (var input of node.inputs) { + var arg_len = input._arguments.length + this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len)) + } + } + + if (this.lastViewGraph._addedNode.get(modelNodeName).outputs.size == 0) { + + for (var output of node.outputs) { + var arg_len = output._arguments.length + this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, new Array(arg_len)) + } } } - // console.log(this.view._graphs[0].nodes) // console.log(this.lastViewGraph._addedNode) } @@ -1140,7 +1150,7 @@ view.Graph = class extends grapher.Graph { // console.log(properties) // this._addedNode.push(new view.LightNodeInfo(properties)) this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties)) - console.log(this._addedNode) + // console.log(this._addedNode) // refresh // this.refresh_added_node() @@ -1173,10 +1183,16 @@ view.Graph = class extends grapher.Graph { // console.log(this._addedNode) } - changeNodeInput(modelNodeName, parameterName, arg_index, targetValue) { + + changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { - this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue + if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { + this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue + } + if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { + this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue + } } // console.log(this._addedNode) } @@ -1407,7 +1423,7 @@ view.LightNodeInfo = class { this.properties = properties this.attributes = attributes || new Map() this.inputs = inputs || new Map() - this.outputs = outputs || [] + this.outputs = outputs || new Map() } }