this is a backup before making some challenging changes for refresh node arguments

1123
ZhangGe6 4 years ago
parent d190aebf91
commit f9aa8840c3

@ -13,7 +13,6 @@ add node (add preprocess nodes): https://zhuanlan.zhihu.com/p/394395167
topk: https://github.com/onnx/onnx/issues/2921 topk: https://github.com/onnx/onnx/issues/2921
# done # done
remove layer: https://github.com/onnx/onnx/issues/2638 remove layer: https://github.com/onnx/onnx/issues/2638
@ -22,4 +21,13 @@ remove layer: https://github.com/onnx/onnx/issues/2638
# 或许可以帮助 # 或许可以帮助
http://yyixx.com/docs/algo/onnx/ http://yyixx.com/docs/algo/onnx/
# 待做的
bug(fixed): 不可连续添加某一种类型的节点(无反应)
boost: 直接使用侧边栏inputs/outputs属性框完成重命名并提供reset功能
boost: 支持处理属性的修改
boost: 支持添加更复杂的节点
question: 在add()函数里为什么对conv的inputs进行遍历只能得到X而得不到W和B

@ -521,7 +521,7 @@ onnx.Graph = class {
return 'graph(' + this.name + ')'; return 'graph(' + this.name + ')';
} }
make_custom_add_node(node_info) { make_custom_added_node(node_info) {
// type of node_info == LightNodeInfo // type of node_info == LightNodeInfo
const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain')); const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain'));
// console.log(schema) // console.log(schema)
@ -865,7 +865,7 @@ onnx.Attribute = class {
// console.log(attribute) // console.log(attribute)
this._value = attribute.value; this._value = attribute.value;
this._type = attribute.type; this._type = attribute.type;
// TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_add_node. This is unsafe // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_added_node. This is unsafe
// throw new onnx.Error("Unknown attribute type '" + attribute.type + "'."); // throw new onnx.Error("Unknown attribute type '" + attribute.type + "'.");
} }
// console.log(attribute.type) // console.log(attribute.type)

@ -862,7 +862,7 @@ sidebar.ArgumentView = class {
// console.log(this._argument) // console.log(this._argument)
// console.log(this._argument.name) // console.log(this._argument.name)
// console.log(e.target.value); // console.log(e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// console.log(this._host._view._graph._renameMap); // console.log(this._host._view._graph._renameMap);
}); });
this._element.appendChild(arg_input); this._element.appendChild(arg_input);

@ -461,6 +461,7 @@ view.View = class {
var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
if (active_graph && this.lastViewGraph) { if (active_graph && this.lastViewGraph) {
this.refreshAddedNode() this.refreshAddedNode()
this.refreshNodeArguments()
} }
return active_graph return active_graph
@ -481,8 +482,7 @@ view.View = class {
// // console.log(this.lastViewGraph._addedNode) // // console.log(this.lastViewGraph._addedNode)
// } // }
const graph = this.activeGraph; const graph = this.activeGraph;
// console.log(graph.nodes) console.log(graph.nodes)
// console.log("_updateGraph is called"); // console.log("_updateGraph is called");
return this._timeout(100).then(() => { return this._timeout(100).then(() => {
@ -582,6 +582,7 @@ view.View = class {
viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State;
viewGraph._renameMap = this.lastViewGraph._renameMap; viewGraph._renameMap = this.lastViewGraph._renameMap;
viewGraph._addedNode = this.lastViewGraph._addedNode; viewGraph._addedNode = this.lastViewGraph._addedNode;
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey
// console.log(viewGraph._renameMap); // console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State) // console.log(viewGraph._modelNodeName2State)
} }
@ -867,12 +868,13 @@ view.View = class {
} }
// re-generate the added node according to _addedNode
refreshAddedNode() { refreshAddedNode() {
this._graphs[0].reset_custom_added_node() this._graphs[0].reset_custom_added_node()
// for (const node_info of this._addedNode.values()) { // for (const node_info of this._addedNode.values()) {
for (const [modelNodeName, node_info] of this.lastViewGraph._addedNode) { for (const [modelNodeName, node_info] of this.lastViewGraph._addedNode) {
// console.log(node_info) // console.log(node_info)
var node = this._graphs[0].make_custom_add_node(node_info) var node = this._graphs[0].make_custom_added_node(node_info)
// console.log(node) // console.log(node)
for (const input of node.inputs) { for (const input of node.inputs) {
@ -897,6 +899,20 @@ view.View = class {
// console.log(this.lastViewGraph._addedNode) // console.log(this.lastViewGraph._addedNode)
} }
// re-fresh node arguments in case the node inputs/outputs are changed
refreshNodeArguments() {
console.log(this._renameMap)
for (var node in this._graphs[0].nodes) {
// this node has some changed arguments
if (this._renameMap.get(node.modelNodeName)) {
}
}
}
}; };
view.Graph = class extends grapher.Graph { view.Graph = class extends grapher.Graph {
@ -915,6 +931,7 @@ view.Graph = class extends grapher.Graph {
createNode(node) { createNode(node) {
var node_id = (this._nodeKey++).toString(); // in case input (onnx) node has no name var node_id = (this._nodeKey++).toString(); // in case input (onnx) node has no name
var modelNodeName = node.name ? node.name : node.type.name + node_id var modelNodeName = node.name ? node.name : node.type.name + node_id
node.modelNodeName = modelNodeName // this will take in-place effect for the onnx.Node in onnx.Graph, which can make it more convenient if we want to find a node in onnx.Graph later
const value = new view.Node(this, node, modelNodeName); const value = new view.Node(this, node, modelNodeName);
value.name = node_id; value.name = node_id;
@ -939,12 +956,21 @@ view.Graph = class extends grapher.Graph {
return value; return value;
} }
createArgument(argument) { // createArgument(argument) {
const name = argument.name; // const name = argument.name;
if (!this._arguments.has(name)) { // if (!this._arguments.has(name)) {
this._arguments.set(name, new view.Argument(this, argument)); // this._arguments.set(name, new view.Argument(this, argument));
// }
// return this._arguments.get(name);
// }
// for change node inputs/outputs compatibility
createArgument(argument, name) {
const arg_name = name || argument.name;
if (!this._arguments.has(arg_name)) {
this._arguments.set(arg_name, new view.Argument(this, argument));
} }
return this._arguments.get(name); return this._arguments.get(arg_name);
} }
createEdge(from, to) { createEdge(from, to) {
@ -976,26 +1002,56 @@ view.Graph = class extends grapher.Graph {
} }
} }
for (const node of graph.nodes) { // console.log(this._renameMap)
const viewNode = this.createNode(node); console.log(graph.nodes)
console.log(this._arguments)
for (var node of graph.nodes) {
var viewNode = this.createNode(node);
const inputs = node.inputs; var inputs = node.inputs;
for (const input of inputs) { for (var input of inputs) {
for (const argument of input.arguments) { for (var argument of input.arguments) {
if (argument.name != '' && !argument.initializer) { if (argument.name != '' && !argument.initializer) {
// if (viewNode.modelNodeName == "Conv3") {
// console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time?
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time?
// }
// if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name) &&
// !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
// )
// {
// argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// argument._renamed = true;
// }
// else { argument._renamed = false; }
// if (viewNode.modelNodeName == "Conv3") {
// console.log("input", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._name) // the linked arguments will be changed at the same time?
// console.log(graph.nodes[2]._outputs[0]._arguments[0]._new_name) // the linked arguments will be changed at the same time?
// }
// if this argument has been renamed // if this argument has been renamed
if ( if (
this._renameMap.get(viewNode.modelNodeName) && this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name) && this._renameMap.get(viewNode.modelNodeName).get(argument._name)
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // in case user clear the input name // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == '' // in case user clear the input name
) )
{ {
argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name)
argument._renamed = true;
} }
else { argument._renamed = false; } else {
var arg_name = argument.name
}
this.createArgument(argument).to(viewNode); this.createArgument(argument, arg_name).to(viewNode);
} }
} }
} }
@ -1006,25 +1062,44 @@ view.Graph = class extends grapher.Graph {
outputs = chainOutputs; outputs = chainOutputs;
} }
} }
for (const output of outputs) { for (var output of outputs) {
for (const argument of output.arguments) { for (var argument of output.arguments) {
if (!argument) { if (!argument) {
throw new view.Error("Invalid null argument in '" + this.model.identifier + "'."); throw new view.Error("Invalid null argument in '" + this.model.identifier + "'.");
} }
if (argument.name != '') { if (argument.name != '') {
// // if this argument has been renamed
// if (
// this._renameMap.get(viewNode.modelNodeName) &&
// this._renameMap.get(viewNode.modelNodeName).get(argument._name) &&
// !this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
// )
// {
// argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
// argument._renamed = true;
// // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
// }
// else { argument._renamed = false; }
// if (viewNode.modelNodeName == "MaxPool2") {
// console.log("output", this._renameMap, viewNode.modelNodeName, argument._name, argument._renamed, argument.name)
// }
// if this argument has been renamed // if this argument has been renamed
if ( if (
this._renameMap.get(viewNode.modelNodeName) && this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name) && this._renameMap.get(viewNode.modelNodeName).get(argument._name)
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // &&!this._renameMap.get(viewNode.modelNodeName).get(argument._name) == ''
) )
{ {
argument._new_name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); var arg_name = this._renameMap.get(viewNode.modelNodeName).get(argument._name);
argument._renamed = true; // console.log("the output of ", viewNode.modelNodeName, "is renamed", argument._renamed, argument.name)
}
else {
var arg_name = argument.name
} }
else { argument._renamed = false; }
this.createArgument(argument).from(viewNode); this.createArgument(argument, arg_name).from(viewNode);
} }
} }
} }
@ -1134,15 +1209,12 @@ view.Graph = class extends grapher.Graph {
} }
add_node(op_domain, op_type) { add_node(op_domain, op_type) {
// node_name: the added op name
// parent_node_name: parent modelNodeName
// console.log(node_name)
var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name
var modelNodeName = 'custom_added_' + op_type + node_id var modelNodeName = 'custom_added_' + op_type + node_id
// console.log(op_type) // console.log(op_type)
// console.log(modelNodeName) console.log(modelNodeName)
var properties = new Map() var properties = new Map()
properties.set('domain', op_domain) properties.set('domain', op_domain)
@ -1164,8 +1236,8 @@ view.Graph = class extends grapher.Graph {
} }
changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue, orig_arg_name) {
if (this._addedNode.has(modelNodeName)) { if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue
} }
@ -1173,9 +1245,21 @@ view.Graph = class extends grapher.Graph {
if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) {
this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue
} }
this.view._updateGraph() // otherwise the changes can not be updated without manully update graph // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph
} }
// console.log(this._addedNode) // console.log(this._addedNode)
else { // for the nodes in the original model
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}
this._renameMap.get(modelNodeName).set(orig_arg_name, targetValue);
// console.log(this._renameMap)
}
this.view._updateGraph()
} }

Loading…
Cancel
Save