try to support renaming model input/output reported in issue https://github.com/ZhangGe6/onnx-modifier/issues/8

1123
ZhangGe6 4 years ago
parent b34157e345
commit 13e00113fb

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
print(modify_info)
# print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times

@ -114,7 +114,9 @@ class onnxModifier:
node.name = dst_name
# print(node.name)
# print(node)
pass
# pass
elif node_name in self.graph_output_names:
node.name = dst_name
else:
# print(node.input, node.output)
for i in range(len(node.input)):

@ -1119,13 +1119,16 @@ sidebar.ModelSidebar = class {
this._addHeader('Inputs');
// for (const input of graph.inputs) {
for (const [index, input] of graph.inputs.entries()){
this.addArgument(input.name, input, index);
this.addArgument(input.name, input, index, 'model_input');
// this.addArgument(input.modelNodeName, input, index, 'model_input');
}
}
if (Array.isArray(graph.outputs) && graph.outputs.length > 0) {
this._addHeader('Outputs');
for (const output of graph.outputs) {
this.addArgument(output.name, output);
// for (const output of graph.outputs) {
for (const [index, output] of graph.outputs.entries()){
this.addArgument(output.name, output, index, 'model_output');
// this.addArgument(output.modelNodeName, output, index, 'model_output');
}
}
}
@ -1151,9 +1154,9 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render());
}
addArgument(name, argument, index) {
addArgument(name, argument, index, arg_type) {
// const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name);
const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name);
view.toggle();
const item = new sidebar.NameValueView(this._host, name, view);
this._elements.push(item.render());

@ -963,12 +963,12 @@ view.View = class {
for (var input of this._graphs[0]._inputs) {
// console.log(input)
// console.log(input.modelNodeName)
if (this.lastViewGraph._renameMap.get(input.modelNodeName)) {
var input_orig_name = input.arguments[0].original_name
if (this.lastViewGraph._renameMap.get(input_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName)
var new_name = this.lastViewGraph._renameMap.get(input_orig_name).get(input_orig_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, input_orig_name)
input.arguments[0] = arg_with_new_name
@ -982,7 +982,7 @@ view.View = class {
for (const [index, element] of node_input.arguments.entries()) {
// console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) {
if (element.original_name == input.modelNodeName) {
if (element.original_name == input_orig_name) {
// console.log(element.name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name)
@ -1009,6 +1009,62 @@ view.View = class {
}
}
}
for (var output of this._graphs[0]._outputs) {
// console.log(output)
// console.log(this.lastViewGraph._renameMap)
// console.log(input.modelNodeName)
// if (this.lastViewGraph._renameMap.get(output.modelNodeName)) {
var output_orig_name = output.arguments[0].original_name
if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get('out_' + output_orig_name).get(output_orig_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, output_orig_name)
output.arguments[0] = arg_with_new_name
// change all the name of node input linked with model input meanwhile
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
// console.log(node.modelNodeName)
// if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
for (var node_output of node.outputs) {
for (const [index, element] of node_output.arguments.entries()) {
// console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) {
// console.log(element.original_name, output.modelNodeName)
if (element.original_name == output_orig_name) {
console.log(element.original_name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
node_output.arguments[index] = arg_with_new_name
// save the changed name into _renameMap
// as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments()
if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) {
this.lastViewGraph._renameMap.set(node.modelNodeName, new Map());
}
var orig_arg_name = element.original_name
this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name);
// console.log(arg_with_new_name)
// console.log(node)
}
}
}
// }
}
}
}
}
@ -1040,9 +1096,17 @@ view.Graph = class extends grapher.Graph {
}
createInput(input) {
const value = new view.Input(this, input);
if (this._renameMap.get(input.name)) {
var show_name = this._renameMap.get(input.name).get(input.name);
}
else {
var show_name = input.name; // input nodes should have name
}
const value = new view.Input(this, input, show_name);
// value.name = (this._nodeKey++).toString();
value.name = input.name; // input nodes should have name
value.name = input.name;
// console.log(value.name)
input.modelNodeName = input.name;
this.setNode(value);
return value;
@ -1050,7 +1114,13 @@ view.Graph = class extends grapher.Graph {
createOutput(output) {
var modelNodeName = "out_" + output.name; // in case the output has the same name with the last node
const value = new view.Output(this, output, modelNodeName);
if (this._renameMap.get(modelNodeName)) {
var show_name = this._renameMap.get(modelNodeName).get(output.name);
}
else {
var show_name = output.name; // input nodes should have name
}
const value = new view.Output(this, output, modelNodeName, show_name);
// value.name = (this._nodeKey++).toString();
value.name = "out_" + output.name; // output nodes should have name
output.modelNodeName = "out_" + output.name;
@ -1326,19 +1396,29 @@ view.Graph = class extends grapher.Graph {
// console.log(this._addedNode)
else { // for the nodes in the original model
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
if (param_type == 'model_input' || param_type == 'model_output') {
if (param_type == 'model_input') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments)
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0])
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name)
// console.log("changing model_input", orig_arg_name)
var orig_arg_name = modelNodeName
// var orig_arg_name = modelNodeName
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
// console.log("changing model_input", orig_arg_name)
// console.log(param_type, orig_arg_name)
}
if (param_type == 'model_output') {
// console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName))
// console.log(this._modelNodeName2ModelNode.get('out_' + modelNodeName).arguments[0].original_name)
modelNodeName = 'out_' + modelNodeName
console.log(modelNodeName)
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
console.log(orig_arg_name)
// console.log("changing model_input", orig_arg_name)
// console.log(param_type, orig_arg_name)
}
if (param_type == 'input') {
@ -1350,6 +1430,9 @@ view.Graph = class extends grapher.Graph {
// console.log(orig_arg_name)
}
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
console.log(this._renameMap)
}
@ -1523,13 +1606,14 @@ view.Node = class extends grapher.Node {
view.Input = class extends grapher.Node {
constructor(context, value) {
constructor(context, value, show_name) {
super();
this.context = context;
this.value = value;
view.Input.counter = view.Input.counter || 0;
const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || '';
// let name = value.name || '';
let name = show_name
this.modelNodeName = value.name
if (name.length > 16) {
name = name.split('/').pop();
@ -1555,12 +1639,13 @@ view.Input = class extends grapher.Node {
view.Output = class extends grapher.Node {
constructor(context, value, modelNodeName) {
constructor(context, value, modelNodeName, show_name) {
super();
this.context = context;
this.value = value;
const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || '';
// let name = value.name || '';
let name = show_name;
this.modelNodeName = modelNodeName
if (name.length > 16) {
name = name.split('/').pop();

Loading…
Cancel
Save