diff --git a/app.py b/app.py index fad078f..ec7c09a 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,7 @@ def open_model(): @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() - # print(modify_info) + print(modify_info) onnx_modifier.reload() # allow downloading for multiple times diff --git a/onnx_modifier.py b/onnx_modifier.py index 7033ae7..2a92a3e 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -44,6 +44,10 @@ class onnxModifier: node.name = str(node.op_type) + str(node_idx) node_idx += 1 self.node_name2module[node.name] = node + + for inp in self.graph.input: + self.node_name2module[inp.name] = inp + self.graph_input_names = [inp.name for inp in self.graph.input] for out in self.graph.output: self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node @@ -76,7 +80,7 @@ class onnxModifier: # print('removing node {} ...'.format(node_name)) self.remove_node_by_name(node_name) - # remove node initializers (parameters) aka, keep and only keep the initializers of left nodes + # remove node initializers (parameters), aka, keep and only keep the initializers of left nodes left_node_inputs = [] for left_node in self.graph.node: left_node_inputs += left_node.input @@ -106,14 +110,20 @@ class onnxModifier: for src_name, dst_name in renamed_ios.items(): # print(src_name, dst_name) node = self.node_name2module[node_name] - # print(node.input, node.output) - for i in range(len(node.input)): - if node.input[i] == src_name: - node.input[i] = dst_name - for i in range(len(node.output)): - if node.output[i] == src_name: - node.output[i] = dst_name - + if node_name in self.graph_input_names: + node.name = dst_name + # print(node.name) + # print(node) + pass + else: + # print(node.input, node.output) + for i in range(len(node.input)): + if node.input[i] == src_name: + node.input[i] = dst_name + for i in range(len(node.output)): + if node.output[i] == src_name: + node.output[i] = dst_name + def modify_node_attr(self, node_changed_attr): # we achieve it by deleting the original node and make a (copied) new node # print(node_changed_attr) @@ -257,10 +267,10 @@ if __name__ == "__main__": def test_modify_node_io_name(): - node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} + node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}} onnx_modifier.modify_node_io_name(node_rename_io) onnx_modifier.check_and_save_model() - # test_modify_node_io_name() + test_modify_node_io_name() def test_add_node(): node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}} @@ -278,7 +288,7 @@ if __name__ == "__main__": onnx_modifier.modify_node_attr(changed_attr) onnx_modifier.check_and_save_model() - test_change_node_attr() + # test_change_node_attr() \ No newline at end of file diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 8f3ba11..6e50c0c 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -801,7 +801,7 @@ class NodeAttributeView { sidebar.ParameterView = class { - constructor(host, list, inp_or_out, param_idx, modelNodeName) { + constructor(host, list, param_type, param_idx, modelNodeName) { this._host = host; this._list = list; this._modelNodeName = modelNodeName @@ -811,7 +811,7 @@ sidebar.ParameterView = class { // console.log(list) // for (const argument of list.arguments) { for (const [arg_idx, argument] of list.arguments.entries()) { - const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName); + const item = new sidebar.ArgumentView(host, argument, param_type, param_idx, arg_idx, list._name, this._modelNodeName); item.on('export-tensor', (sender, tensor) => { this._raise('export-tensor', tensor); }); @@ -850,10 +850,10 @@ sidebar.ParameterView = class { sidebar.ArgumentView = class { - constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) { + constructor(host, argument, param_type, param_index, arg_index, parameterName, modelNodeName) { this._host = host; this._argument = argument; - this._inp_or_out = inp_or_out + this._param_type = param_type this._param_index = param_index this._arg_index = arg_index this._parameterName = parameterName @@ -907,7 +907,7 @@ sidebar.ArgumentView = class { // console.log(e.target.value); // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); - this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value); + this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, e.target.value); // console.log(this._host._view._graph._renameMap); }); this._element.appendChild(arg_input); @@ -1117,8 +1117,9 @@ sidebar.ModelSidebar = class { } if (Array.isArray(graph.inputs) && graph.inputs.length > 0) { this._addHeader('Inputs'); - for (const input of graph.inputs) { - this.addArgument(input.name, input); + // for (const input of graph.inputs) { + for (const [index, input] of graph.inputs.entries()){ + this.addArgument(input.name, input, index); } } if (Array.isArray(graph.outputs) && graph.outputs.length > 0) { @@ -1150,8 +1151,9 @@ sidebar.ModelSidebar = class { this._elements.push(item.render()); } - addArgument(name, argument) { - const view = new sidebar.ParameterView(this._host, argument); + addArgument(name, argument, index) { + // const view = new sidebar.ParameterView(this._host, argument); + const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name); view.toggle(); const item = new sidebar.NameValueView(this._host, name, view); this._elements.push(item.render()); diff --git a/static/view.js b/static/view.js index abce47d..ec39b69 100644 --- a/static/view.js +++ b/static/view.js @@ -13,9 +13,6 @@ var python = python || require('./python'); var sidebar = sidebar || require('./view-sidebar'); var grapher = grapher || require('./view-grapher'); -// var onnx = onnx || require('./onnx'); - - view.View = class { constructor(host, id) { @@ -464,8 +461,10 @@ view.View = class { var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; if (active_graph && this.lastViewGraph) { this.refreshAddedNode() + this.refreshModelInputOutput() this.refreshNodeArguments() - + this.refreshNodeAttributes() + } return active_graph @@ -910,14 +909,14 @@ view.View = class { if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { // console.log(element.name) var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) - console.log(element.original_name) - console.log(new_name) + // console.log(element.original_name) + // console.log(new_name) var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) input.arguments[index] = arg_with_new_name - console.log(arg_with_new_name) - console.log(node) + // console.log(arg_with_new_name) + // console.log(node) } } } @@ -942,6 +941,9 @@ view.View = class { } } + } + + refreshNodeAttributes() { for (const node_name of this.lastViewGraph._changedAttributes.keys()) { var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name) var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) @@ -953,7 +955,62 @@ view.View = class { } } + } + + refreshModelInputOutput() { + // console.log("refreshModelInputOutput", this._graphs[0]) + // console.log(this.lastViewGraph._renameMap) + for (var input of this._graphs[0]._inputs) { + // console.log(input) + // console.log(input.modelNodeName) + + if (this.lastViewGraph._renameMap.get(input.modelNodeName)) { + // for model input and output, node.modelNodeName == element.original_name + var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName) + // console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName) + + input.arguments[0] = arg_with_new_name + + // change all the name of node input linked with model input meanwhile + for (var node of this._graphs[0]._nodes) { + // this node has some changed arguments + // console.log(node) + // console.log(node.modelNodeName) + // if (this.lastViewGraph._renameMap.get(node.modelNodeName)) { + for (var node_input of node.inputs) { + for (const [index, element] of node_input.arguments.entries()) { + // console.log(element.name, input.modelNodeName) + // if (element.name == input.modelNodeName) { + if (element.original_name == input.modelNodeName) { + // console.log(element.name) + // var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) + // console.log(element.original_name) + // console.log(new_name) + var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) + + node_input.arguments[index] = arg_with_new_name + + // save the changed name into _renameMap + // as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments() + if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) { + this.lastViewGraph._renameMap.set(node.modelNodeName, new Map()); + } + + var orig_arg_name = element.original_name + this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name); + + // console.log(arg_with_new_name) + // console.log(node) + } + } + } + // } + } + + } + } } }; @@ -986,6 +1043,7 @@ view.Graph = class extends grapher.Graph { const value = new view.Input(this, input); // value.name = (this._nodeKey++).toString(); value.name = input.name; // input nodes should have name + input.modelNodeName = input.name; this.setNode(value); return value; } @@ -995,6 +1053,7 @@ view.Graph = class extends grapher.Graph { const value = new view.Output(this, output, modelNodeName); // value.name = (this._nodeKey++).toString(); value.name = "out_" + output.name; // output nodes should have name + output.modelNodeName = "out_" + output.name; this.setNode(value); return value; } @@ -1175,24 +1234,35 @@ view.Graph = class extends grapher.Graph { for (const changed_node_name of this._renameMap.keys()) { var node = this._modelNodeName2ModelNode.get(changed_node_name) console.log(node) - //reset inputs - for (var input of node.inputs) { - for (var i = 0; i < input.arguments.length; ++i) { - // console.log(input.arguments[i].original_name) - if (this._renameMap.get(node.modelNodeName).get(input.arguments[i].original_name)) { - input.arguments[i] = this.view._graphs[0]._context.argument(input.arguments[i].original_name) - } - } + // console.log(typeof node) + // console.log(node.constructor.name) + if (node.arguments) { // model input or model output. Because they are purely onnx.Parameter + node.arguments[0] = this.view._graphs[0]._context.argument(node.modelNodeName) } - // reset outputs - for (var output of node.outputs) { - for (var i = 0; i < output.arguments.length; ++i) { - if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) { - output.arguments[i] = this.view._graphs[0]._context.argument(output.arguments[i].original_name) + else { // model nodes + //reset inputs + for (var input of node.inputs) { + for (var i = 0; i < input.arguments.length; ++i) { + // console.log(input.arguments[i].original_name) + if (this._renameMap.get(node.modelNodeName).get(input.arguments[i].original_name)) { + input.arguments[i] = this.view._graphs[0]._context.argument(input.arguments[i].original_name) + } + } + } + + // reset outputs + for (var output of node.outputs) { + for (var i = 0; i < output.arguments.length; ++i) { + if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) { + output.arguments[i] = this.view._graphs[0]._context.argument(output.arguments[i].original_name) + } } } + } + + } this._renameMap = new Map(); @@ -1241,7 +1311,7 @@ view.Graph = class extends grapher.Graph { } - changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) { + changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue, orig_arg_name) { // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { @@ -1260,11 +1330,22 @@ view.Graph = class extends grapher.Graph { this._renameMap.set(modelNodeName, new Map()); } - if (inp_or_out == 'input') { + if (param_type == 'model_input' || param_type == 'model_output') { + // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name + // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments) + // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0]) + // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name) + // console.log("changing model_input", orig_arg_name) + + var orig_arg_name = modelNodeName + // console.log("changing model_input", orig_arg_name) + } + + if (param_type == 'input') { var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name // console.log(orig_arg_name) } - if (inp_or_out == 'output') { + if (param_type == 'output') { var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name // console.log(orig_arg_name) }