try to support renaming model input/output reported in issue https://github.com/ZhangGe6/onnx-modifier/issues/8

1123
ZhangGe6 4 years ago
parent b34157e345
commit 13e00113fb

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST']) @app.route('/download', methods=['POST'])
def modify_and_download_model(): def modify_and_download_model():
modify_info = request.get_json() modify_info = request.get_json()
print(modify_info) # print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.reload() # allow downloading for multiple times

@ -114,7 +114,9 @@ class onnxModifier:
node.name = dst_name node.name = dst_name
# print(node.name) # print(node.name)
# print(node) # print(node)
pass # pass
elif node_name in self.graph_output_names:
node.name = dst_name
else: else:
# print(node.input, node.output) # print(node.input, node.output)
for i in range(len(node.input)): for i in range(len(node.input)):

@ -1119,13 +1119,16 @@ sidebar.ModelSidebar = class {
this._addHeader('Inputs'); this._addHeader('Inputs');
// for (const input of graph.inputs) { // for (const input of graph.inputs) {
for (const [index, input] of graph.inputs.entries()){ for (const [index, input] of graph.inputs.entries()){
this.addArgument(input.name, input, index); this.addArgument(input.name, input, index, 'model_input');
// this.addArgument(input.modelNodeName, input, index, 'model_input');
} }
} }
if (Array.isArray(graph.outputs) && graph.outputs.length > 0) { if (Array.isArray(graph.outputs) && graph.outputs.length > 0) {
this._addHeader('Outputs'); this._addHeader('Outputs');
for (const output of graph.outputs) { // for (const output of graph.outputs) {
this.addArgument(output.name, output); for (const [index, output] of graph.outputs.entries()){
this.addArgument(output.name, output, index, 'model_output');
// this.addArgument(output.modelNodeName, output, index, 'model_output');
} }
} }
} }
@ -1151,9 +1154,9 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render()); this._elements.push(item.render());
} }
addArgument(name, argument, index) { addArgument(name, argument, index, arg_type) {
// const view = new sidebar.ParameterView(this._host, argument); // const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name); const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name);
view.toggle(); view.toggle();
const item = new sidebar.NameValueView(this._host, name, view); const item = new sidebar.NameValueView(this._host, name, view);
this._elements.push(item.render()); this._elements.push(item.render());

@ -963,12 +963,12 @@ view.View = class {
for (var input of this._graphs[0]._inputs) { for (var input of this._graphs[0]._inputs) {
// console.log(input) // console.log(input)
// console.log(input.modelNodeName) // console.log(input.modelNodeName)
var input_orig_name = input.arguments[0].original_name
if (this.lastViewGraph._renameMap.get(input.modelNodeName)) { if (this.lastViewGraph._renameMap.get(input_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name // for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName) var new_name = this.lastViewGraph._renameMap.get(input_orig_name).get(input_orig_name)
// console.log(new_name) // console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName) var arg_with_new_name = this._graphs[0]._context.argument(new_name, input_orig_name)
input.arguments[0] = arg_with_new_name input.arguments[0] = arg_with_new_name
@ -982,7 +982,7 @@ view.View = class {
for (const [index, element] of node_input.arguments.entries()) { for (const [index, element] of node_input.arguments.entries()) {
// console.log(element.name, input.modelNodeName) // console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) { // if (element.name == input.modelNodeName) {
if (element.original_name == input.modelNodeName) { if (element.original_name == input_orig_name) {
// console.log(element.name) // console.log(element.name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) // var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name) // console.log(element.original_name)
@ -1009,6 +1009,62 @@ view.View = class {
} }
}
}
for (var output of this._graphs[0]._outputs) {
// console.log(output)
// console.log(this.lastViewGraph._renameMap)
// console.log(input.modelNodeName)
// if (this.lastViewGraph._renameMap.get(output.modelNodeName)) {
var output_orig_name = output.arguments[0].original_name
if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get('out_' + output_orig_name).get(output_orig_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, output_orig_name)
output.arguments[0] = arg_with_new_name
// change all the name of node input linked with model input meanwhile
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
// console.log(node.modelNodeName)
// if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
for (var node_output of node.outputs) {
for (const [index, element] of node_output.arguments.entries()) {
// console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) {
// console.log(element.original_name, output.modelNodeName)
if (element.original_name == output_orig_name) {
console.log(element.original_name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
node_output.arguments[index] = arg_with_new_name
// save the changed name into _renameMap
// as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments()
if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) {
this.lastViewGraph._renameMap.set(node.modelNodeName, new Map());
}
var orig_arg_name = element.original_name
this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name);
// console.log(arg_with_new_name)
// console.log(node)
}
}
}
// }
}
} }
} }
} }
@ -1040,9 +1096,17 @@ view.Graph = class extends grapher.Graph {
} }
createInput(input) { createInput(input) {
const value = new view.Input(this, input); if (this._renameMap.get(input.name)) {
var show_name = this._renameMap.get(input.name).get(input.name);
}
else {
var show_name = input.name; // input nodes should have name
}
const value = new view.Input(this, input, show_name);
// value.name = (this._nodeKey++).toString(); // value.name = (this._nodeKey++).toString();
value.name = input.name; // input nodes should have name
value.name = input.name;
// console.log(value.name)
input.modelNodeName = input.name; input.modelNodeName = input.name;
this.setNode(value); this.setNode(value);
return value; return value;
@ -1050,7 +1114,13 @@ view.Graph = class extends grapher.Graph {
createOutput(output) { createOutput(output) {
var modelNodeName = "out_" + output.name; // in case the output has the same name with the last node var modelNodeName = "out_" + output.name; // in case the output has the same name with the last node
const value = new view.Output(this, output, modelNodeName); if (this._renameMap.get(modelNodeName)) {
var show_name = this._renameMap.get(modelNodeName).get(output.name);
}
else {
var show_name = output.name; // input nodes should have name
}
const value = new view.Output(this, output, modelNodeName, show_name);
// value.name = (this._nodeKey++).toString(); // value.name = (this._nodeKey++).toString();
value.name = "out_" + output.name; // output nodes should have name value.name = "out_" + output.name; // output nodes should have name
output.modelNodeName = "out_" + output.name; output.modelNodeName = "out_" + output.name;
@ -1326,19 +1396,29 @@ view.Graph = class extends grapher.Graph {
// console.log(this._addedNode) // console.log(this._addedNode)
else { // for the nodes in the original model else { // for the nodes in the original model
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
if (param_type == 'model_input' || param_type == 'model_output') { if (param_type == 'model_input') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name // var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments) // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments)
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0]) // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0])
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name) // console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name)
// console.log("changing model_input", orig_arg_name) // console.log("changing model_input", orig_arg_name)
var orig_arg_name = modelNodeName // var orig_arg_name = modelNodeName
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
// console.log("changing model_input", orig_arg_name) // console.log("changing model_input", orig_arg_name)
// console.log(param_type, orig_arg_name)
}
if (param_type == 'model_output') {
// console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName))
// console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName).arguments[0].original_name)
modelNodeName = 'out_' + modelNodeName
console.log(modelNodeName)
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
console.log(orig_arg_name)
// console.log("changing model_input", orig_arg_name)
// console.log(param_type, orig_arg_name)
} }
if (param_type == 'input') { if (param_type == 'input') {
@ -1350,6 +1430,9 @@ view.Graph = class extends grapher.Graph {
// console.log(orig_arg_name) // console.log(orig_arg_name)
} }
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
console.log(this._renameMap) console.log(this._renameMap)
} }
@ -1523,13 +1606,14 @@ view.Node = class extends grapher.Node {
view.Input = class extends grapher.Node { view.Input = class extends grapher.Node {
constructor(context, value) { constructor(context, value, show_name) {
super(); super();
this.context = context; this.context = context;
this.value = value; this.value = value;
view.Input.counter = view.Input.counter || 0; view.Input.counter = view.Input.counter || 0;
const types = value.arguments.map((argument) => argument.type || '').join('\n'); const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || ''; // let name = value.name || '';
let name = show_name
this.modelNodeName = value.name this.modelNodeName = value.name
if (name.length > 16) { if (name.length > 16) {
name = name.split('/').pop(); name = name.split('/').pop();
@ -1555,12 +1639,13 @@ view.Input = class extends grapher.Node {
view.Output = class extends grapher.Node { view.Output = class extends grapher.Node {
constructor(context, value, modelNodeName) { constructor(context, value, modelNodeName, show_name) {
super(); super();
this.context = context; this.context = context;
this.value = value; this.value = value;
const types = value.arguments.map((argument) => argument.type || '').join('\n'); const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || ''; // let name = value.name || '';
let name = show_name;
this.modelNodeName = modelNodeName this.modelNodeName = modelNodeName
if (name.length > 16) { if (name.length > 16) {
name = name.split('/').pop(); name = name.split('/').pop();

Loading…
Cancel
Save