diff --git a/docs/onnx_modifier_logo_1.png b/docs/onnx_modifier_logo.png similarity index 100% rename from docs/onnx_modifier_logo_1.png rename to docs/onnx_modifier_logo.png diff --git a/onnx_modifier.py b/onnx_modifier.py index 8c50f75..c40f430 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -139,7 +139,7 @@ class onnxModifier: def modify(self, modify_info): # print(modify_info['node_states']) - print(modify_info['node_renamed_io']) + # print(modify_info['node_renamed_io']) # print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) self.remove_node_by_node_states(modify_info['node_states']) @@ -174,12 +174,9 @@ class onnxModifier: input_name = inference_session.get_inputs()[0].name output_name = inference_session.get_outputs()[0].name - # print(input_name) - # print(output_name) - + # This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506 out = inference_session.run(None, {input_name: x})[0] - # print(out) @@ -281,7 +278,6 @@ if __name__ == "__main__": onnx_modifier.modify_node_attr(changed_attr) onnx_modifier.check_and_save_model() - test_change_node_attr() diff --git a/readme.md b/readme.md index 2dbf9d8..4ccb0d7 100644 --- a/readme.md +++ b/readme.md @@ -1,4 +1,4 @@ - + English | [简体中文](readme_zh-CN.md) @@ -12,7 +12,7 @@ Then `onnx-modifier` comes. With it, we can focus on editing the model graph in `onnx-modifier` is built based on the popular network viewer [Netron](https://github.com/lutzroeder/netron) and the lightweight web application framework [flask](https://github.com/pallets/flask). -The update log of `onnx-modifier` can be seen [here](./docs/update_log.md). Currently, the following editing operations are supported: +Currently, the following editing operations are supported: - Delete/recover nodes - Delete a single node. diff --git a/readme_zh-CN.md b/readme_zh-CN.md index 293aee6..4846eb3 100644 --- a/readme_zh-CN.md +++ b/readme_zh-CN.md @@ -1,4 +1,4 @@ - + 简体中文 | [English](readme.md) diff --git a/static/onnx.js b/static/onnx.js index 408d7a7..427cabe 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -555,15 +555,6 @@ onnx.Graph = class { } } else { - // if (node_info_input) { - // if (!node_info_input[0]) { - // console.log('got empty') - // } - // else { - // console.log(node_info_input[0]) - // } - // } - if (node_info_input && node_info_input[0]) { var arg_name = node_info_input[0] } @@ -573,18 +564,6 @@ onnx.Graph = class { arg_list = [this._context.argument(arg_name)] } - // var arg_list = [] - // if (input.list) { - // for (let j = 0; j < max_custom_add_input_num; ++j) { - // var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() - // arg_list.push(this._context.argument(arg_name)) - // } - // } - // else { - // var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() - // arg_list = [this._context.argument(arg_name)] - // } - inputs.push(new onnx.Parameter(input.name, arg_list)); } @@ -621,13 +600,11 @@ onnx.Graph = class { // console.log(inputs) // console.log(outputs) - console.log(node_info) + // console.log(node_info) var attributes = [] if (schema.attributes) { for (const attr of schema.attributes) { - console.log(attr) var value = node_info.attributes.get(attr.name) // modified value or null - console.log(value) attributes.push( new onnx.LightAttributeInfo( attr.name, @@ -638,8 +615,6 @@ onnx.Graph = class { ) } } - console.log(attributes[0].value) - var custom_add_node = new onnx.Node( this._context, node_info.properties.get('op_type'), @@ -691,18 +666,11 @@ onnx.Argument = class { this._annotation = annotation; this._description = description || ''; - // this._renamed = false; - // this._new_name = null; - // console.log(original_name) this.original_name = original_name || name - // console.log(this.original_name) } get name() { - // if (this._renamed) { - // return this._new_name; - // } return this._name; } @@ -868,19 +836,14 @@ onnx.Attribute = class { // console.log(attribute) this._value = attribute.value; this._type = attribute.type; - // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is unsafe + // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is may be unsafe // throw new onnx.Error("Unknown attribute type '" + attribute.type + "'."); } - // console.log(attribute.type) - // console.log(this._value) - // console.log(this._type) // see #L1294 GraphMetadata const metadata = context.metadata.attribute(op_type, domain, attribute.name); // console.log(metadata) if (metadata) { - // console.log(Object.prototype.hasOwnProperty.call(metadata, 'default') && this._value == metadata.default) // false - // console.log(metadata.type === 'DataType') // false if (Object.prototype.hasOwnProperty.call(metadata, 'default') && this._value == metadata.default) { this._visible = false; } @@ -1768,13 +1731,9 @@ onnx.GraphContext = class { } tensor(name) { - // console.log(this._tensors) - // console.log(name) - if (!this._tensors.has(name)) { this._tensors.set(name, { name: name }); } - // console.log(this._tensors) return this._tensors.get(name); } @@ -1792,12 +1751,9 @@ onnx.GraphContext = class { argument(name, original_name) { const tensor = this.tensor(name); - // console.log(name) - // console.log(tensor) const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; return new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name); - // // if (!this._arguments.has(name)) { // if ((!this._arguments.has(name)) || // (this._arguments.has(name) && !this._arguments.get(name).original_name == original_name) @@ -1811,22 +1767,6 @@ onnx.GraphContext = class { // return this._arguments.get(name); } - mayChangedArgument(name, rename_map) { - // console.log(name, rename_map) - if (rename_map.get(name)) { - var arg_name = rename_map.get(name) - } - else { - var arg_name = name - } - // console.log(arg_name) - - // console.log(this._arguments) - - return this.argument(arg_name) - - } - createType(type) { if (!type) { return null; @@ -1921,9 +1861,6 @@ onnx.GraphContext = class { node.input.length === 0 && node.output.length === 1 && node.output[0] && inputMap.get(node.output[0].name) === 1 && outputMap.get(node.output[0].name) === 1; const attribute = constant ? node.attribute[0] : null; - // console.log(node) - // console.log(constant) // false - // console.log(attribute) // null if (attribute && attribute.name === 'value' && attribute.type === onnx.AttributeType.TENSOR && attribute.t) { const tensor = this.tensor(node.output[0].name); tensor.initializer = new onnx.Tensor(this, attribute.t, 'Constant'); diff --git a/static/view.js b/static/view.js index e012ae9..683db45 100644 --- a/static/view.js +++ b/static/view.js @@ -479,14 +479,9 @@ view.View = class { this.UpdateAddNodeDropDown(); } this.lastViewGraph = this._graph; - // if (this.lastViewGraph) { - // // console.log(this.lastViewGraph._addedNode) - // // console.log(this.lastViewGraph._modelNodeName2State) - // } const graph = this.activeGraph; // console.log(graph.nodes) - // console.log("_updateGraph is called"); return this._timeout(100).then(() => { if (graph && graph != lastGraphs[0]) { const nodes = graph.nodes; @@ -579,8 +574,6 @@ view.View = class { const viewGraph = new view.Graph(this, model, groups, options); if (this.lastViewGraph) { - // console.log(this.lastViewGraph._modelNodeName2State) - // console.log('node state of lastViewGraph is loaded') viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; viewGraph._renameMap = this.lastViewGraph._renameMap; viewGraph._changedAttributes = this.lastViewGraph._changedAttributes; @@ -898,16 +891,10 @@ view.View = class { } } - // console.log(this._graphs[0].nodes) - // console.log(this.lastViewGraph._addedNode) } // re-fresh node arguments in case the node inputs/outputs are changed - refreshNodeArguments() { - // console.log(this.lastViewGraph._renameMap) - // console.log(this._graphs[0]) - // console.log(this._graphs[0]._nodes) - + refreshNodeArguments() { for (var node of this._graphs[0]._nodes) { // this node has some changed arguments // console.log(node) @@ -917,11 +904,6 @@ view.View = class { // check inputs for (var input of node.inputs) { for (const [index, element] of input.arguments.entries()) { - // console.log(element) - // console.log(element.orig_arg_name) - // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) - // input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) - // if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) { if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { // console.log(element.name) var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) @@ -940,11 +922,6 @@ view.View = class { // check outputs for (var output of node.outputs) { for (const [index, element] of output.arguments.entries()) { - // console.log(element) - // console.log(element.orig_arg_name) - // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) - // input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) - // if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) { if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { // console.log(element.name) var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) @@ -959,25 +936,13 @@ view.View = class { } } - // check outputs - // for (var output of node.outputs) { - // for (const [index, element] of output.arguments.entries()) { - // // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))) - // output.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)) - // } - // } } } - // console.log(this._graphs[0]._context._arguments) - for (const node_name of this.lastViewGraph._changedAttributes.keys()) { var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name) var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) - // console.log(node) - // console.log(attr_change_map) - // for (const attr of node._attributes) { for (var i = 0; i < node._attributes.length; ++i) { if (attr_change_map.get(node._attributes[i].name)) { node._attributes[i]._value = attr_change_map.get(node._attributes[i].name) @@ -1077,55 +1042,13 @@ view.Graph = class extends grapher.Graph { } } - // console.log(this._renameMap) - // console.log(graph.nodes) - // console.log(this._arguments) for (var node of graph.nodes) { var viewNode = this.createNode(node); var inputs = node.inputs; for (var input of inputs) { for (var argument of input.arguments) { - if (argument.name != '' && !argument.initializer) { - - // if (viewNode.modelNodeName == "Conv3") { - // console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) - // console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time? - // console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time? - // } - - // if this argument has been renamed - // if ( - // this._renameMap.get(viewNode.modelNodeName) && - // this._renameMap.get(viewNode.modelNodeName).get(argument._name) && - // !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name - // ) - // { - // argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); - // argument._renamed = true; - // } - // else { argument._renamed = false; } - - // if (viewNode.modelNodeName == "Conv3") { - // console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) - // console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time? - // console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time? - // } - - // if this argument has been renamed - // if ( - // this._renameMap.get(viewNode.modelNodeName) && - // this._renameMap.get(viewNode.modelNodeName).get(argument._name) - // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name - // ) - // { - // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name) - // } - // else { - // var arg_name = argument.name - // } - // this.createArgument(argument, arg_name).to(viewNode); - + if (argument.name != '' && !argument.initializer) { this.createArgument(argument).to(viewNode); } } @@ -1143,38 +1066,6 @@ view.Graph = class extends grapher.Graph { throw new view.Error("Invalid null argument in '" + this.model.identifier + "'."); } if (argument.name != '') { - // // if this argument has been renamed - // if ( - // this._renameMap.get(viewNode.modelNodeName) && - // this._renameMap.get(viewNode.modelNodeName).get(argument._name) && - // !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' - // ) - // { - // argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); - // argument._renamed = true; - // // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) - // } - // else { argument._renamed = false; } - - // if (viewNode.modelNodeName == "MaxPool2") { - // console.log("output", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name) - // } - - // if this argument has been renamed - // if ( - // this._renameMap.get(viewNode.modelNodeName) && - // this._renameMap.get(viewNode.modelNodeName).get(argument._name) - // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' - // ) - // { - // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); - // // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) - // } - // else { - // var arg_name = argument.name - // } - // this.createArgument(argument, arg_name).from(viewNode); - this.createArgument(argument).from(viewNode); } } @@ -1319,15 +1210,10 @@ view.Graph = class extends grapher.Graph { var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name var modelNodeName = 'custom_added_' + op_type + node_id - // console.log(op_type) - // console.log(modelNodeName) - var properties = new Map() properties.set('domain', op_domain) properties.set('op_type', op_type) properties.set('name', modelNodeName) - // console.log(properties) - // this._addedNode.push(new view.LightNodeInfo(properties)) this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties)) // console.log(this._addedNode) @@ -1365,29 +1251,17 @@ view.Graph = class extends grapher.Graph { // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph } // console.log(this._addedNode) - - // else { - // if (!this._renameMap.get(modelNodeName)) { - // this._renameMap.set(modelNodeName, new Map()); - // } - // this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); - // // console.log(this._renameMap) - // } else { // for the nodes in the original model if (!this._renameMap.get(modelNodeName)) { this._renameMap.set(modelNodeName, new Map()); } - // var changed_node = this._modelNodeName2ModelNode.get(modelNodeName) - // console.log(inp_or_out, param_index, arg_index) - // console.log(changed_node) + if (inp_or_out == 'input') { - // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name // console.log(orig_arg_name) } if (inp_or_out == 'output') { - // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name // console.log(orig_arg_name) } diff --git a/utils/make_nodes.py b/utils/make_nodes.py index 6aa3ab1..8b16430 100644 --- a/utils/make_nodes.py +++ b/utils/make_nodes.py @@ -67,10 +67,7 @@ def make_attr_changed_node(node, attr_change_info): # https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L548 new_attr[attr.name] = onnx.helper.get_attribute_value(attr) # print(new_attr) - # print(node.input, node.output) - - # print(node) - + node = onnx.helper.make_node( op_type=node.op_type, inputs=node.input,