fixed shared `argument` problem

1123
ZhangGe6 4 years ago
parent f9aa8840c3
commit ffe2173b40

@ -681,7 +681,7 @@ onnx.Parameter = class {
onnx.Argument = class {
constructor(name, type, initializer, annotation, description) {
constructor(name, type, initializer, annotation, description, original_name) {
if (typeof name !== 'string') {
throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
}
@ -691,15 +691,16 @@ onnx.Argument = class {
this._annotation = annotation;
this._description = description || '';
this._renamed = false;
this._new_name = null;
// this._renamed = false;
// this._new_name = null;
this.original_name = original_name || name
}
get name() {
if (this._renamed) {
return this._new_name;
}
// if (this._renamed) {
// return this._new_name;
// }
return this._name;
}
@ -1785,17 +1786,33 @@ onnx.GraphContext = class {
return this._groups.get(name);
}
argument(name) {
argument(name, original_name) {
if (!this._arguments.has(name)) {
const tensor = this.tensor(name);
// console.log(name)
// console.log(tensor)
const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description));
this._arguments.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name));
}
return this._arguments.get(name);
}
mayChangedArgument(name, rename_map) {
// console.log(name, rename_map)
if (rename_map.get(name)) {
var arg_name = rename_map.get(name)
}
else {
var arg_name = name
}
// console.log(arg_name)
// console.log(this._arguments)
return this.argument(arg_name)
}
createType(type) {
if (!type) {
return null;

@ -15,6 +15,7 @@ grapher.Graph = class {
// My code
this._modelNodeName2ViewNode = new Map();
this._modelNodeName2ModelNode = new Map();
this._modelNodeName2State = new Map();
this._namedEdges = new Map();
@ -45,6 +46,7 @@ grapher.Graph = class {
const modelNodeName = node.modelNodeName
this._modelNodeName2ViewNode.set(modelNodeName, node);
this._modelNodeName2ModelNode.set(modelNodeName, node.value)
// _modelNodeName2State save our modifications, and wil be initilized at the first graph construction only
// otherwise the modfications will lost

@ -182,16 +182,18 @@ sidebar.NodeSidebar = class {
const inputs = node.inputs;
if (inputs && inputs.length > 0) {
this._addHeader('Inputs');
for (const input of inputs) {
this._addInput(input.name, input); // 这里的input.name是小白格前面的名称不是方格内的
for (const [index, input] of inputs.entries()){
// for (const input of inputs) {
this._addInput(input.name, input, index); // 这里的input.name是小白格前面的名称不是方格内的
}
}
const outputs = node.outputs;
if (outputs && outputs.length > 0) {
this._addHeader('Outputs');
for (const output of outputs) {
this._addOutput(output.name, output);
for (const [index, output] of outputs.entries()){
// for (const output of outputs) {
this._addOutput(output.name, output, index);
}
}
@ -297,10 +299,10 @@ sidebar.NodeSidebar = class {
this._elements.push(view.render());
}
_addInput(name, input) {
_addInput(name, input, param_idx) {
// console.log(input)
if (input.arguments.length > 0) {
const view = new sidebar.ParameterView(this._host, input, this._modelNodeName);
const view = new sidebar.ParameterView(this._host, input, 'input', param_idx, this._modelNodeName);
view.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor);
});
@ -314,9 +316,9 @@ sidebar.NodeSidebar = class {
}
}
_addOutput(name, output) {
_addOutput(name, output, param_idx) {
if (output.arguments.length > 0) {
const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, this._modelNodeName));
const item = new sidebar.NameValueView(this._host, name, new sidebar.ParameterView(this._host, output, 'output', param_idx, this._modelNodeName));
this._outputs.push(item);
this._elements.push(item.render());
}
@ -760,7 +762,7 @@ class NodeAttributeView {
sidebar.ParameterView = class {
constructor(host, list, modelNodeName) {
constructor(host, list, inp_or_out, param_idx, modelNodeName) {
this._host = host;
this._list = list;
this._modelNodeName = modelNodeName
@ -769,8 +771,8 @@ sidebar.ParameterView = class {
// console.log(list)
// for (const argument of list.arguments) {
for (const [index, argument] of list.arguments.entries()) {
const item = new sidebar.ArgumentView(host, argument, index, list._name, this._modelNodeName);
for (const [arg_idx, argument] of list.arguments.entries()) {
const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName);
item.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor);
});
@ -809,9 +811,11 @@ sidebar.ParameterView = class {
sidebar.ArgumentView = class {
constructor(host, argument, arg_index, parameterName, modelNodeName) {
constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) {
this._host = host;
this._argument = argument;
this._inp_or_out = inp_or_out
this._param_index = param_index
this._arg_index = arg_index
this._parameterName = parameterName
this._modelNodeName = modelNodeName
@ -862,7 +866,9 @@ sidebar.ArgumentView = class {
// console.log(this._argument)
// console.log(this._argument.name)
// console.log(e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value);
// console.log(this._host._view._graph._renameMap);
});
this._element.appendChild(arg_input);

@ -482,7 +482,7 @@ view.View = class {
// // console.log(this.lastViewGraph._addedNode)
// }
const graph = this.activeGraph;
console.log(graph.nodes)
// console.log(graph.nodes)
// console.log("_updateGraph is called");
return this._timeout(100).then(() => {
@ -901,16 +901,84 @@ view.View = class {
// re-fresh node arguments in case the node inputs/outputs are changed
refreshNodeArguments() {
console.log(this._renameMap)
for (var node in this._graphs[0].nodes) {
console.log(this.lastViewGraph._renameMap)
// console.log(this._graphs[0])
console.log(this._graphs[0]._nodes)
// for (const node_name of this.lastViewGraph._renameMap.keys()) {
// var rename_map = this.lastViewGraph._renameMap.get(node_name)
// var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
// console.log(node)
// console.log(rename_map)
// }
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
if (this._renameMap.get(node.modelNodeName)) {
// console.log(node)
// console.log(node.modelNodeName)
if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
// check inputs
for (var input of node.inputs) {
for (const [index, element] of input.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
input.arguments[index] = arg_with_new_name
// console.log(arg_with_new_name)
console.log(node)
}
}
}
// check outputs
for (var output of node.outputs) {
for (const [index, element] of output.arguments.entries()) {
// console.log(element)
// console.log(element.orig_arg_name)
// console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// input.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.name)) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
output.arguments[index] = arg_with_new_name
// console.log(arg_with_new_name)
console.log(node)
}
}
}
// check outputs
// for (var output of node.outputs) {
// for (const [index, element] of output.arguments.entries()) {
// // console.log(this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName)))
// output.arguments[index] = this._graphs[0]._context.mayChangedArgument(element.name, this.lastViewGraph._renameMap.get(node.modelNodeName))
// }
// }
}
}
// console.log(this._graphs[0]._context._arguments)
}
};
@ -956,22 +1024,22 @@ view.Graph = class extends grapher.Graph {
return value;
}
// createArgument(argument) {
// const name = argument.name;
// if (!this._arguments.has(name)) {
// this._arguments.set(name, new view.Argument(this, argument));
// }
// return this._arguments.get(name);
// }
// for change node inputs/outputs compatibility
createArgument(argument, name) {
const arg_name = name || argument.name;
if (!this._arguments.has(arg_name)) {
this._arguments.set(arg_name, new view.Argument(this, argument));
createArgument(argument) {
const name = argument.name;
if (!this._arguments.has(name)) {
this._arguments.set(name, new view.Argument(this, argument));
}
return this._arguments.get(arg_name);
return this._arguments.get(name);
}
// for change node inputs/outputs compatibility
// createArgument(argument, name) {
// const arg_name = name || argument.name;
// if (!this._arguments.has(arg_name)) {
// this._arguments.set(arg_name, new view.Argument(this, argument));
// }
// return this._arguments.get(arg_name);
// }
createEdge(from, to) {
const value = new view.Edge(from, to);
@ -1003,8 +1071,8 @@ view.Graph = class extends grapher.Graph {
}
// console.log(this._renameMap)
console.log(graph.nodes)
console.log(this._arguments)
// console.log(graph.nodes)
// console.log(this._arguments)
for (var node of graph.nodes) {
var viewNode = this.createNode(node);
@ -1038,20 +1106,20 @@ view.Graph = class extends grapher.Graph {
// }
// if this argument has been renamed
if (
this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
)
{
var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name)
}
else {
var arg_name = argument.name
}
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
// )
// {
// var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// }
// else {
// var arg_name = argument.name
// }
// this.createArgument(argument, arg_name).to(viewNode);
this.createArgument(argument, arg_name).to(viewNode);
this.createArgument(argument).to(viewNode);
}
}
}
@ -1086,20 +1154,21 @@ view.Graph = class extends grapher.Graph {
// }
// if this argument has been renamed
if (
this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
)
{
var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
}
else {
var arg_name = argument.name
}
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name)
// // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
// )
// {
// var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
// }
// else {
// var arg_name = argument.name
// }
// this.createArgument(argument, arg_name).from(viewNode);
this.createArgument(argument, arg_name).from(viewNode);
this.createArgument(argument).from(viewNode);
}
}
}
@ -1214,7 +1283,7 @@ view.Graph = class extends grapher.Graph {
var modelNodeName = 'custom_added_' + op_type + node_id
// console.log(op_type)
console.log(modelNodeName)
// console.log(modelNodeName)
var properties = new Map()
properties.set('domain', op_domain)
@ -1236,7 +1305,8 @@ view.Graph = class extends grapher.Graph {
}
changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue, orig_arg_name) {
changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue
@ -1249,13 +1319,34 @@ view.Graph = class extends grapher.Graph {
}
// console.log(this._addedNode)
else { // for the nodes in the original model
// else { // for the nodes in the original model
// if (!this._renameMap.get(modelNodeName)) {
// this._renameMap.set(modelNodeName, new Map());
// }
// this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
// // console.log(this._renameMap)
// }
else {
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
// console.log(this._renameMap)
// var changed_node = this._modelNodeName2ModelNode.get(modelNodeName)
// console.log(inp_or_out, param_index, arg_index)
// console.log(changed_node)
if (inp_or_out == 'input') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (inp_or_out == 'output') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].name
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
console.log(this._renameMap)
}
this.view._updateGraph()

Loading…
Cancel
Save