fixed shared `argument` problem

1123
ZhangGe6 4 years ago
parent f9aa8840c3
commit ffe2173b40

@ -681,7 +681,7 @@ onnx.Parameter = class {
onnx.Argument = class { onnx.Argument = class {
constructor(name, type, initializer, annotation, description) { constructor(name, type, initializer, annotation, description, original_name) {
if (typeof name !== 'string') { if (typeof name !== 'string') {
throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(name) + "'."); throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
} }
@ -691,15 +691,16 @@ onnx.Argument = class {
this._annotation = annotation; this._annotation = annotation;
this._description = description || ''; this._description = description || '';
this._renamed = false; // this._renamed = false;
this._new_name = null; // this._new_name = null;
this.original_name = original_name || name
} }
get name() { get name() {
if (this._renamed) { // if (this._renamed) {
return this._new_name; // return this._new_name;
} // }
return this._name; return this._name;
} }
@ -1785,17 +1786,33 @@ onnx.GraphContext = class {
return this._groups.get(name); return this._groups.get(name);
} }
argument(name) { argument(name, original_name) {
if (!this._arguments.has(name)) { if (!this._arguments.has(name)) {
const tensor = this.tensor(name); const tensor = this.tensor(name);
// console.log(name) // console.log(name)
// console.log(tensor) // console.log(tensor)
const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description)); this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name));
} }
return this._arguments.get(name); return this._arguments.get(name);
} }
mayChangedArgument(name, rename_map) {
// console.log(name, rename_map)
if (rename_map.get(name)) {
var arg_name = rename_map.get(name)
}
else {
var arg_name = name
}
// console.log(arg_name)
// console.log(this._arguments)
return this.argument(arg_name)
}
createType(type) { createType(type) {
if (!type) { if (!type) {
return null; return null;

@ -15,6 +15,7 @@ grapher.Graph = class {
// My code // My code
this._modelNodeName2ViewNode = new Map(); this._modelNodeName2ViewNode = new Map();
this._modelNodeName2ModelNode = new Map();
this._modelNodeName2State = new Map(); this._modelNodeName2State = new Map();
this._namedEdges = new Map(); this._namedEdges = new Map();
@ -45,6 +46,7 @@ grapher.Graph = class {
const modelNodeName = node.modelNodeName const modelNodeName = node.modelNodeName
this._modelNodeName2ViewNode.set(modelNodeName, node); this._modelNodeName2ViewNode.set(modelNodeName, node);
this._modelNodeName2ModelNode.set(modelNodeName, node.value)
// _modelNodeName2State save our modifications, and wil be initilized at the first graph construction only // _modelNodeName2State save our modifications, and wil be initilized at the first graph construction only
// otherwise the modfications will lost // otherwise the modfications will lost

@ -182,16 +182,18 @@ sidebar.NodeSidebar = class {
const inputs = node.inputs; const inputs = node.inputs;
if (inputs && inputs.length > 0) { if (inputs && inputs.length > 0) {
this._addHeader('Inputs'); this._addHeader('Inputs');
for (const input of inputs) { for (const [index, input] of inputs.entries()){
this._addInput(input.name, input); // 这里的input.name是小白格前面的名称不是方格内的 // for (const input of inputs) {
this._addInput(input.name, input, index); // 这里的input.name是小白格前面的名称不是方格内的
} }
} }
const outputs = node.outputs; const outputs = node.outputs;
if (outputs && outputs.length > 0) { if (outputs && outputs.length > 0) {
this._addHeader('Outputs'); this._addHeader('Outputs');
for (const output of outputs) { for (const [index, output] of outputs.entries()){
this._addOutput(output.name, output); // for (const output of outputs) {
this._addOutput(output.name, output, index);
} }
} }
@ -297,10 +299,10 @@ sidebar.NodeSidebar = class {
this._elements.push(view.render()); this._elements.push(view.render());
} }
_addInput(name, input) { _addInput(name, input, param_idx) {
// console.log(input) // console.log(input)
if (input.arguments.length > 0) { if (input.arguments.length > 0) {
const view = new sidebar.ParameterView(this._host, input, this._modelNodeName); const view = new sidebar.ParameterView(this._host, input, 'input', param_idx, this._modelNodeName);
view.on('export-tensor', (sender, tensor) => { view.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor); this._raise('export-tensor', tensor);
}); });
@ -314,9 +316,9 @@ sidebar.NodeSidebar = class {
} }
} }
_addOutput(name, output) { _addOutput(name, output, param_idx) {
if (output.arguments.length > 0) { if (output.arguments.length > 0) {
const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, this._modelNodeName)); const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, 'output', param_idx, this._modelNodeName));
this._outputs.push(item); this._outputs.push(item);
this._elements.push(item.render()); this._elements.push(item.render());
} }
@ -760,7 +762,7 @@ class NodeAttributeView {
sidebar.ParameterView = class { sidebar.ParameterView = class {
constructor(host, list, modelNodeName) { constructor(host, list, inp_or_out, param_idx, modelNodeName) {
this._host = host; this._host = host;
this._list = list; this._list = list;
this._modelNodeName = modelNodeName this._modelNodeName = modelNodeName
@ -769,8 +771,8 @@ sidebar.ParameterView = class {
// console.log(list) // console.log(list)
// for (const argument of list.arguments) { // for (const argument of list.arguments) {
for (const [index, argument] of list.arguments.entries()) { for (const [arg_idx, argument] of list.arguments.entries()) {
const item = new sidebar.ArgumentView(host, argument, index, list._name, this._modelNodeName); const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName);
item.on('export-tensor', (sender, tensor) => { item.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor); this._raise('export-tensor', tensor);
}); });
@ -809,9 +811,11 @@ sidebar.ParameterView = class {
sidebar.ArgumentView = class { sidebar.ArgumentView = class {
constructor(host, argument, arg_index, parameterName, modelNodeName) { constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) {
this._host = host; this._host = host;
this._argument = argument; this._argument = argument;
this._inp_or_out = inp_or_out
this._param_index = param_index
this._arg_index = arg_index this._arg_index = arg_index
this._parameterName = parameterName this._parameterName = parameterName
this._modelNodeName = modelNodeName this._modelNodeName = modelNodeName
@ -862,7 +866,9 @@ sidebar.ArgumentView = class {
// console.log(this._argument) // console.log(this._argument)
// console.log(this._argument.name) // console.log(this._argument.name)
// console.log(e.target.value); // console.log(e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value);
// console.log(this._host._view._graph._renameMap); // console.log(this._host._view._graph._renameMap);
}); });
this._element.appendChild(arg_input); this._element.appendChild(arg_input);

@ -482,7 +482,7 @@ view.View = class {
// // console.log(this.lastViewGraph._addedNode) // // console.log(this.lastViewGraph._addedNode)
// } // }
const graph = this.activeGraph; const graph = this.activeGraph;
console.log(graph.nodes) // console.log(graph.nodes)
// console.log("_updateGraph is called"); // console.log("_updateGraph is called");
return this._timeout(100).then(() => { return this._timeout(100).then(() => {
@ -901,16 +901,84 @@ view.View = class {
// re-fresh node arguments in case the node inputs/outputs are changed // re-fresh node arguments in case the node inputs/outputs are changed
refreshNodeArguments() { refreshNodeArguments() {
console.log(this._renameMap) console.log(this.lastViewGraph._renameMap)
for (var node in this._graphs[0].nodes) { // console.log(this._graphs[0])
console.log(this._graphs[0]._nodes)
// for (const node_name of this.lastViewGraph._renameMap.keys()) {
// var rename_map = this.lastViewGraph._renameMap.get(node_name)
// var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
// console.log(node)
// console.log(rename_map)
// }
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments // this node has some changed arguments
if (this._renameMap.get(node.modelNodeName)) { // console.log(node)
// console.log(node.modelNodeName)
if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
// check inputs
for (var input of node.inputs) {
for (const [index, element] of input.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
input.arguments[index] = arg_with_new_name
// console.log(arg_with_new_name)
console.log(node)
}
}
}
// check outputs
for (var output of node.outputs) {
for (const [index, element] of output.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
output.arguments[index] = arg_with_new_name
// console.log(arg_with_new_name)
console.log(node)
}
}
}
// check outputs
// for (var output of node.outputs) {
// for (const [index, element] of output.arguments.entries()) {
// // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// output.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// }
// }
} }
} }
// console.log(this._graphs[0]._context._arguments)
} }
}; };
@ -956,22 +1024,22 @@ view.Graph = class extends grapher.Graph {
return value; return value;
} }
// createArgument(argument) { createArgument(argument) {
// const name = argument.name; const name = argument.name;
// if (!this._arguments.has(name)) { if (!this._arguments.has(name)) {
// this._arguments.set(name, new view.Argument(this, argument)); this._arguments.set(name, new view.Argument(this, argument));
// }
// return this._arguments.get(name);
// }
// for change node inputs/outputs compatibility
createArgument(argument, name) {
const arg_name = name || argument.name;
if (!this._arguments.has(arg_name)) {
this._arguments.set(arg_name, new view.Argument(this, argument));
} }
return this._arguments.get(arg_name); return this._arguments.get(name);
} }
// for change node inputs/outputs compatibility
// createArgument(argument, name) {
// const arg_name = name || argument.name;
// if (!this._arguments.has(arg_name)) {
// this._arguments.set(arg_name, new view.Argument(this, argument));
// }
// return this._arguments.get(arg_name);
// }
createEdge(from, to) { createEdge(from, to) {
const value = new view.Edge(from, to); const value = new view.Edge(from, to);
@ -1003,8 +1071,8 @@ view.Graph = class extends grapher.Graph {
} }
// console.log(this._renameMap) // console.log(this._renameMap)
console.log(graph.nodes) // console.log(graph.nodes)
console.log(this._arguments) // console.log(this._arguments)
for (var node of graph.nodes) { for (var node of graph.nodes) {
var viewNode = this.createNode(node); var viewNode = this.createNode(node);
@ -1038,20 +1106,20 @@ view.Graph = class extends grapher.Graph {
// } // }
// if this argument has been renamed // if this argument has been renamed
if ( // if (
this._renameMap.get(viewNode.modelNodeName) && // this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument._name) // this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
) // )
{ // {
var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name) // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name)
} // }
else { // else {
var arg_name = argument.name // var arg_name = argument.name
} // }
// this.createArgument(argument, arg_name).to(viewNode);
this.createArgument(argument).to(viewNode);
this.createArgument(argument, arg_name).to(viewNode);
} }
} }
} }
@ -1086,20 +1154,21 @@ view.Graph = class extends grapher.Graph {
// } // }
// if this argument has been renamed // if this argument has been renamed
if ( // if (
this._renameMap.get(viewNode.modelNodeName) && // this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument._name) // this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
) // )
{ // {
var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name); // var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name) // // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
} // }
else { // else {
var arg_name = argument.name // var arg_name = argument.name
} // }
// this.createArgument(argument, arg_name).from(viewNode);
this.createArgument(argument, arg_name).from(viewNode); this.createArgument(argument).from(viewNode);
} }
} }
} }
@ -1214,7 +1283,7 @@ view.Graph = class extends grapher.Graph {
var modelNodeName = 'custom_added_' + op_type + node_id var modelNodeName = 'custom_added_' + op_type + node_id
// console.log(op_type) // console.log(op_type)
console.log(modelNodeName) // console.log(modelNodeName)
var properties = new Map() var properties = new Map()
properties.set('domain', op_domain) properties.set('domain', op_domain)
@ -1236,7 +1305,8 @@ view.Graph = class extends grapher.Graph {
} }
changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue, orig_arg_name) { changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue
@ -1249,13 +1319,34 @@ view.Graph = class extends grapher.Graph {
} }
// console.log(this._addedNode) // console.log(this._addedNode)
else { // for the nodes in the original model // else { // for the nodes in the original model
// if (!this._renameMap.get(modelNodeName)) {
// this._renameMap.set(modelNodeName, new Map());
// }
// this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
// // console.log(this._renameMap)
// }
else {
if (!this._renameMap.get(modelNodeName)) { if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map()); this._renameMap.set(modelNodeName, new Map());
} }
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue); // var changed_node = this._modelNodeName2ModelNode.get(modelNodeName)
// console.log(this._renameMap) // console.log(inp_or_out, param_index, arg_index)
// console.log(changed_node)
if (inp_or_out == 'input') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (inp_or_out == 'output') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
console.log(this._renameMap)
} }
this.view._updateGraph() this.view._updateGraph()

Loading…
Cancel
Save