diff --git a/onnx_modifier.py b/onnx_modifier.py index 6b310cf..d08c110 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -98,8 +98,6 @@ class onnxModifier: if len(const_node_left_output) == 0: self.graph.node.remove(self.node_name2module[left_node.name]) # self.initializer.remove(self.initializer_name2module[init_name]) - else: - left_node.output = const_node_left_output def modify_node_io_name(self, node_renamed_io): diff --git a/readme.md b/readme.md index 0a4821a..28fb59c 100644 --- a/readme.md +++ b/readme.md @@ -19,6 +19,7 @@ The update log of `onnx-modifier` can be seen [here](./docs/update_log.md). Curr - Delete a node and all the nodes rooted on it. - Recover a deleted node. - Rename the name of node inputs/outputs +- Edit the attribute of nodes - Add new nodes (experimental) @@ -86,7 +87,6 @@ By changing the input/output name of nodes, we can change the model forward path Using `onnx-modifier`, we can achieve this by simply enter a new name for node inputs/outputs in its corresponding input placeholder. The graph topology is updated automatically and instantly, according to the new names. - For example, Now we want remove the preprocess operators (`Sub->Mul->Sub->Transpose`) shown in the following figure. We can 1. Click on the 1st `Conv` node, rename its input (X) as *serving_default_input:0* (the output of node `data_0`). @@ -97,6 +97,12 @@ For example, Now we want remove the preprocess operators (`Sub->Mul->Sub->Trans rename_node_io +## Edit the attribute of nodes + +Change the original attribute to a new value, then we are done. + + + ## Add new node Sometimes we want to add new nodes into the exsited model. `onnx-modifier` supports this feature experimentally now. diff --git a/static/onnx.js b/static/onnx.js index 2de8d14..408d7a7 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -621,13 +621,13 @@ onnx.Graph = class { // console.log(inputs) // console.log(outputs) - // console.log(node_info) + console.log(node_info) var attributes = [] if (schema.attributes) { for (const attr of schema.attributes) { - // console.log(attr) + console.log(attr) var value = node_info.attributes.get(attr.name) // modified value or null - // console.log(value) + console.log(value) attributes.push( new onnx.LightAttributeInfo( attr.name, @@ -638,7 +638,7 @@ onnx.Graph = class { ) } } - // console.log(attributes) + console.log(attributes[0].value) var custom_add_node = new onnx.Node( this._context, @@ -918,7 +918,9 @@ onnx.LightAttributeInfo = class { this.name = name; this.description = description; this.type = type; - this.value = value || null; + // this.value = value || null; + // console.log(value, value || null) // TODO: amazing output: 0, null + this.value = value } } diff --git a/static/view.js b/static/view.js index f963356..e012ae9 100644 --- a/static/view.js +++ b/static/view.js @@ -1284,7 +1284,7 @@ view.Graph = class extends grapher.Graph { //reset inputs for (var input of node.inputs) { for (var i = 0; i < input.arguments.length; ++i) { - console.log(input.arguments[i].original_name) + // console.log(input.arguments[i].original_name) if (this._renameMap.get(node.modelNodeName).get(input.arguments[i].original_name)) { input.arguments[i] = this.view._graphs[0]._context.argument(input.arguments[i].original_name) } @@ -1292,16 +1292,19 @@ view.Graph = class extends grapher.Graph { } // reset outputs - // for (var output of node.outputs) { - // for (var i = 0; i < output.arguments.length; ++i) { - // if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) { - // output.arguments[index] = this._graphs[0]._context.argument(output.arguments[i].original_name) - // } - // } - // } + for (var output of node.outputs) { + for (var i = 0; i < output.arguments.length; ++i) { + if (this._renameMap.get(node.modelNodeName).get(output.arguments[i].original_name)) { + output.arguments[i] = this.view._graphs[0]._context.argument(output.arguments[i].original_name) + } + } + } } this._renameMap = new Map(); + // clear custom added nodes + this._addedNode = new Map() + this.view._graphs[0].reset_custom_added_node() } recordRenameInfo(modelNodeName, src_name, dst_name) {