polish the code and push to github

1123
ZhangGe6 4 years ago
parent 03d8ea4d80
commit 7861092c55

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

@ -139,7 +139,7 @@ class onnxModifier:
def modify(self, modify_info):
# print(modify_info['node_states'])
print(modify_info['node_renamed_io'])
# print(modify_info['node_renamed_io'])
# print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
self.remove_node_by_node_states(modify_info['node_states'])
@ -174,12 +174,9 @@ class onnxModifier:
input_name = inference_session.get_inputs()[0].name
output_name = inference_session.get_outputs()[0].name
# print(input_name)
# print(output_name)
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
out = inference_session.run(None, {input_name: x})[0]
# print(out)
@ -281,7 +278,6 @@ if __name__ == "__main__":
onnx_modifier.modify_node_attr(changed_attr)
onnx_modifier.check_and_save_model()
test_change_node_attr()

@ -1,4 +1,4 @@
<img src="./docs/onnx_modifier_logo_1.png" style="zoom: 60%;" />
<img src="./docs/onnx_modifier_logo.png" style="zoom: 60%;" />
English | [简体中文](readme_zh-CN.md)
@ -12,7 +12,7 @@ Then `onnx-modifier` comes. With it, we can focus on editing the model graph in
`onnx-modifier` is built based on the popular network viewer [Netron](https://github.com/lutzroeder/netron) and the lightweight web application framework [flask](https://github.com/pallets/flask).
The update log of `onnx-modifier` can be seen [here](./docs/update_log.md). Currently, the following editing operations are supported:
Currently, the following editing operations are supported:
- Delete/recover nodes
- Delete a single node.

@ -1,4 +1,4 @@
<img src="./docs/onnx_modifier_logo_1.png" style="zoom: 60%;" />
<img src="./docs/onnx_modifier_logo.png" style="zoom: 60%;" />
简体中文 | [English](readme.md)

@ -555,15 +555,6 @@ onnx.Graph = class {
}
}
else {
// if (node_info_input) {
// if (!node_info_input[0]) {
// console.log('got empty')
// }
// else {
// console.log(node_info_input[0])
// }
// }
if (node_info_input && node_info_input[0]) {
var arg_name = node_info_input[0]
}
@ -573,18 +564,6 @@ onnx.Graph = class {
arg_list = [this._context.argument(arg_name)]
}
// var arg_list = []
// if (input.list) {
// for (let j = 0; j < max_custom_add_input_num; ++j) {
// var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString()
// arg_list.push(this._context.argument(arg_name))
// }
// }
// else {
// var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString()
// arg_list = [this._context.argument(arg_name)]
// }
inputs.push(new onnx.Parameter(input.name, arg_list));
}
@ -621,13 +600,11 @@ onnx.Graph = class {
// console.log(inputs)
// console.log(outputs)
console.log(node_info)
// console.log(node_info)
var attributes = []
if (schema.attributes) {
for (const attr of schema.attributes) {
console.log(attr)
var value = node_info.attributes.get(attr.name) // modified value or null
console.log(value)
attributes.push(
new onnx.LightAttributeInfo(
attr.name,
@ -638,8 +615,6 @@ onnx.Graph = class {
)
}
}
console.log(attributes[0].value)
var custom_add_node = new onnx.Node(
this._context,
node_info.properties.get('op_type'),
@ -691,18 +666,11 @@ onnx.Argument = class {
this._annotation = annotation;
this._description = description || '';
// this._renamed = false;
// this._new_name = null;
// console.log(original_name)
this.original_name = original_name || name
// console.log(this.original_name)
}
get name() {
// if (this._renamed) {
// return this._new_name;
// }
return this._name;
}
@ -868,19 +836,14 @@ onnx.Attribute = class {
// console.log(attribute)
this._value = attribute.value;
this._type = attribute.type;
// TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is unsafe
// TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is may be unsafe
// throw new onnx.Error("Unknown attribute type '" + attribute.type + "'.");
}
// console.log(attribute.type)
// console.log(this._value)
// console.log(this._type)
// see #L1294 GraphMetadata
const metadata = context.metadata.attribute(op_type, domain, attribute.name);
// console.log(metadata)
if (metadata) {
// console.log(Object.prototype.hasOwnProperty.call(metadata, 'default') && this._value == metadata.default) // false
// console.log(metadata.type === 'DataType') // false
if (Object.prototype.hasOwnProperty.call(metadata, 'default') && this._value == metadata.default) {
this._visible = false;
}
@ -1768,13 +1731,9 @@ onnx.GraphContext = class {
}
tensor(name) {
// console.log(this._tensors)
// console.log(name)
if (!this._tensors.has(name)) {
this._tensors.set(name, { name: name });
}
// console.log(this._tensors)
return this._tensors.get(name);
}
@ -1792,12 +1751,9 @@ onnx.GraphContext = class {
argument(name, original_name) {
const tensor = this.tensor(name);
// console.log(name)
// console.log(tensor)
const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
return new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name);
// // if (!this._arguments.has(name)) {
// if ((!this._arguments.has(name)) ||
// (this._arguments.has(name) && !this._arguments.get(name).original_name == original_name)
@ -1811,22 +1767,6 @@ onnx.GraphContext = class {
// return this._arguments.get(name);
}
mayChangedArgument(name, rename_map) {
// console.log(name, rename_map)
if (rename_map.get(name)) {
var arg_name = rename_map.get(name)
}
else {
var arg_name = name
}
// console.log(arg_name)
// console.log(this._arguments)
return this.argument(arg_name)
}
createType(type) {
if (!type) {
return null;
@ -1921,9 +1861,6 @@ onnx.GraphContext = class {
node.input.length === 0 &&
node.output.length === 1 && node.output[0] && inputMap.get(node.output[0].name) === 1 && outputMap.get(node.output[0].name) === 1;
const attribute = constant ? node.attribute[0] : null;
// console.log(node)
// console.log(constant) // false
// console.log(attribute) // null
if (attribute && attribute.name === 'value' && attribute.type === onnx.AttributeType.TENSOR && attribute.t) {
const tensor = this.tensor(node.output[0].name);
tensor.initializer = new onnx.Tensor(this, attribute.t, 'Constant');

@ -479,14 +479,9 @@ view.View = class {
this.UpdateAddNodeDropDown();
}
this.lastViewGraph = this._graph;
// if (this.lastViewGraph) {
// // console.log(this.lastViewGraph._addedNode)
// // console.log(this.lastViewGraph._modelNodeName2State)
// }
const graph = this.activeGraph;
// console.log(graph.nodes)
// console.log("_updateGraph is called");
return this._timeout(100).then(() => {
if (graph && graph != lastGraphs[0]) {
const nodes = graph.nodes;
@ -579,8 +574,6 @@ view.View = class {
const viewGraph = new view.Graph(this, model, groups, options);
if (this.lastViewGraph) {
// console.log(this.lastViewGraph._modelNodeName2State)
// console.log('node state of lastViewGraph is loaded')
viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State;
viewGraph._renameMap = this.lastViewGraph._renameMap;
viewGraph._changedAttributes = this.lastViewGraph._changedAttributes;
@ -898,16 +891,10 @@ view.View = class {
}
}
// console.log(this._graphs[0].nodes)
// console.log(this.lastViewGraph._addedNode)
}
// re-fresh node arguments in case the node inputs/outputs are changed
refreshNodeArguments() {
// console.log(this.lastViewGraph._renameMap)
// console.log(this._graphs[0])
// console.log(this._graphs[0]._nodes)
refreshNodeArguments() {
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
@ -917,11 +904,6 @@ view.View = class {
// check inputs
for (var input of node.inputs) {
for (const [index, element] of input.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
@ -940,11 +922,6 @@ view.View = class {
// check outputs
for (var output of node.outputs) {
for (const [index, element] of output.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
@ -959,25 +936,13 @@ view.View = class {
}
}
// check outputs
// for (var output of node.outputs) {
// for (const [index, element] of output.arguments.entries()) {
// // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// output.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// }
// }
}
}
// console.log(this._graphs[0]._context._arguments)
for (const node_name of this.lastViewGraph._changedAttributes.keys()) {
var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name)
var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
// console.log(node)
// console.log(attr_change_map)
// for (const attr of node._attributes) {
for (var i = 0; i < node._attributes.length; ++i) {
if (attr_change_map.get(node._attributes[i].name)) {
node._attributes[i]._value = attr_change_map.get(node._attributes[i].name)
@ -1077,55 +1042,13 @@ view.Graph = class extends grapher.Graph {
}
}
// console.log(this._renameMap)
// console.log(graph.nodes)
// console.log(this._arguments)
for (var node of graph.nodes) {
var viewNode = this.createNode(node);
var inputs = node.inputs;
for (var input of inputs) {
for (var argument of input.arguments) {
if (argument.name != '' && !argument.initializer) {
// if (viewNode.modelNodeName == "Conv3") {
// console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time?
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time?
// }
// if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name) &&
// !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
// )
// {
// argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// argument._renamed = true;
// }
// else { argument._renamed = false; }
// if (viewNode.modelNodeName == "Conv3") {
// console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time?
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time?
// }
// if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
// )
// {
// var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// }
// else {
// var arg_name = argument.name
// }
// this.createArgument(argument, arg_name).to(viewNode);
if (argument.name != '' && !argument.initializer) {
this.createArgument(argument).to(viewNode);
}
}
@ -1143,38 +1066,6 @@ view.Graph = class extends grapher.Graph {
throw new view.Error("Invalid null argument in '" + this.model.identifier + "'.");
}
if (argument.name != '') {
// // if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name) &&
// !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
// )
// {
// argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// argument._renamed = true;
// // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
// }
// else { argument._renamed = false; }
// if (viewNode.modelNodeName == "MaxPool2") {
// console.log("output", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// }
// if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
// )
// {
// var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
// }
// else {
// var arg_name = argument.name
// }
// this.createArgument(argument, arg_name).from(viewNode);
this.createArgument(argument).from(viewNode);
}
}
@ -1319,15 +1210,10 @@ view.Graph = class extends grapher.Graph {
var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name
var modelNodeName = 'custom_added_' + op_type + node_id
// console.log(op_type)
// console.log(modelNodeName)
var properties = new Map()
properties.set('domain', op_domain)
properties.set('op_type', op_type)
properties.set('name', modelNodeName)
// console.log(properties)
// this._addedNode.push(new view.LightNodeInfo(properties))
this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties))
// console.log(this._addedNode)
@ -1365,29 +1251,17 @@ view.Graph = class extends grapher.Graph {
// this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph
}
// console.log(this._addedNode)
// else {
// if (!this._renameMap.get(modelNodeName)) {
// this._renameMap.set(modelNodeName, new Map());
// }
// this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
// // console.log(this._renameMap)
// }
else { // for the nodes in the original model
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
// var changed_node = this._modelNodeName2ModelNode.get(modelNodeName)
// console.log(inp_or_out, param_index, arg_index)
// console.log(changed_node)
if (inp_or_out == 'input') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (inp_or_out == 'output') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}

@ -67,10 +67,7 @@ def make_attr_changed_node(node, attr_change_info):
# https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L548
new_attr[attr.name] = onnx.helper.get_attribute_value(attr)
# print(new_attr)
# print(node.input, node.output)
# print(node)
node = onnx.helper.make_node(
op_type=node.op_type,
inputs=node.input,

Loading…
Cancel
Save