remove node for nu-named node done. refresh graph for un-named node do not work

1123
ZhangGe6 4 years ago
parent 4e69fa0d48
commit c75a50257d

@ -1,8 +1,8 @@
# https://leimao.github.io/blog/ONNX-Python-API/
# https://github.com/saurabh-shandilya/onnx-utils
#
import io
import os
from platform import node
import onnx
class onnxModifier:
@ -16,20 +16,19 @@ class onnxModifier:
def gen_node_name2module_map(self):
self.node_name2module = dict()
node_idx = 0
for node in self.graph.input:
node_idx += 1
self.node_name2module[node.name] = node
# for node in self.graph.input:
# node_idx += 1
# self.node_name2module[node.name] = node
for node in self.graph.node:
if node.name == '':
node.name = str(node.op_type) + str(node_idx)
node_idx += 1
self.node_name2module[node.name] = node
for node in self.graph.output:
node_idx += 1
self.node_name2module[node.name] = node
self.graph_output_names = [node.name for node in self.graph.output]
for out in self.graph.output:
self.node_name2module[out.name] = out
self.graph_output_names = [out.name for out in self.graph.output]
# print(self.node_name2module.keys())
@classmethod
@ -41,7 +40,6 @@ class onnxModifier:
@classmethod
def from_name_stream(cls, name, stream):
# https://leimao.github.io/blog/ONNX-IO-Stream/
stream.seek(0)
model_proto = onnx.load_model(stream, onnx.ModelProto)
return cls(name, model_proto)
@ -56,8 +54,10 @@ class onnxModifier:
for node_name, node_state in node_states.items():
if node_state == 'Deleted':
if node_name in self.graph_output_names:
# print('removing output {} ...'.format(node_name))
self.remove_output_by_name(node_name)
else:
# print('removing node {} ...'.format(node_name))
self.remove_node_by_name(node_name)
def check_and_save_model(self, save_dir='./res_onnx'):
@ -72,13 +72,23 @@ class onnxModifier:
if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
onnx_modifier.remove_node_by_name('Softmax_nc_rename_64')
onnx_modifier.remove_output_by_name('softmaxout_1')
# for node in onnx_modifier.graph.node:
# print(node.name)
# for node in onnx_modifier.graph.output:
# print(node.name)
print(onnx_modifier.node_name2module.keys())
print(onnx_modifier.graph_output_names)
# onnx_modifier.remove_node_by_name('Softmax_nc_rename_64')
# onnx_modifier.remove_output_by_name('softmaxout_1')
# onnx_modifier.graph.output.remove(onnx_modifier.node_name2module['softmaxout_1'])
onnx_modifier.check_and_save_model()
# onnx_modifier.check_and_save_model()
# print(type(onnx_modifier.graph.input))
# print(type(onnx_modifier.graph.output))
# print(onnx_modifier.graph.input)
# print(onnx_modifier.graph.output)
# print(onnx_modifier.node_name2module['Softmax_nc_rename_64'])

@ -24,7 +24,7 @@ grapher.Graph = class {
}
setNode(node) {
const key = node.name;
const key = node.name; // node id
const value = this._nodes.get(key);
if (value) {
value.label = node;
@ -39,11 +39,18 @@ grapher.Graph = class {
}
// My code
const modelNodeName = node.value.name;
// var modelNodeName = node.value.name;
// if (modelNodeName == '') { // in case that model node has no name
// modelNodeName = node.value.type.name + node.name
// // console.log(node.value)
// // console.log(modelNodeName)
// }
const modelNodeName = node.modelNodeName
this._modelNodeName2ViewNode.set(modelNodeName, node);
this._modelNodeName2State.set(modelNodeName, 'Exist');
// console.log(modelNodeName)
// console.log(node.modelNodeName)
}
setEdge(edge) {
@ -60,8 +67,10 @@ grapher.Graph = class {
// My code
// _namedEdges: from : to
var from_node_name = edge.from.value.name
var to_node_name = edge.to.value.name
// var from_node_name = edge.from.value.name
// var to_node_name = edge.to.value.name
var from_node_name = edge.from.modelNodeName
var to_node_name = edge.to.modelNodeName
if (!this._namedEdges.has(from_node_name)) {
this._namedEdges.set(from_node_name, []);
}

@ -124,9 +124,10 @@ sidebar.Sidebar = class {
sidebar.NodeSidebar = class {
constructor(host, node) {
constructor(host, node, modelNodeName) {
this._host = host;
this._node = node;
this._modelNodeName = modelNodeName;
this._elements = [];
this._attributes = [];
this._inputs = [];
@ -292,18 +293,18 @@ sidebar.NodeSidebar = class {
// console.log(title)
if (title === 'Delete') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.delete_node(this._node.name)
this._host._view._graph.delete_node(this._modelNodeName)
});
}
if (title === 'DeleteWithChildren') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.delete_node_with_children(this._node.name)
this._host._view._graph.delete_node_with_children(this._modelNodeName)
});
}
if (title === 'Recover') {
// console.log('pressed')
buttonElement.addEventListener('click', () => {
this._host._view._graph.recover_node(this._node.name)
this._host._view._graph.recover_node(this._modelNodeName)
});
}

@ -482,10 +482,14 @@ view.View = class {
const graph = this.activeGraph;
this.lastViewGraph = this._graph;
// if (lastViewGraph) {
// console.log(this._graph)
// this.lastModelNodeName2State = lastViewGraph._modelNodeName2State;
// }
// console.log("this.lastViewGraph")
// console.log(this.lastViewGraph)
if (this.lastViewGraph) {
// console.log(this._graph)
// this.lastModelNodeName2State = lastViewGraph._modelNodeName2State;
console.log('lastViewGraph _modelNodeName2State')
console.log(this.lastViewGraph._modelNodeName2State)
}
// console.log("_updateGraph is called");
return this._timeout(100).then(() => {
@ -582,7 +586,7 @@ view.View = class {
const viewGraph = new view.Graph(this, model, groups, options);
// console.log(viewGraph)
if (this.lastViewGraph) {
// console.log(this.lastViewGraph._modelNodeName2ViewNode)
// console.log(this.lastViewGraph._modelNodeName2State)
// console.log('node state of lastViewGraph is loaded')
viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State;
}
@ -792,13 +796,13 @@ view.View = class {
}
}
showNodeProperties(node, input) {
showNodeProperties(node, input, modelNodeName) {
if (node) {
try {
// console.log(node) // 注意是onnx.Node, 不是grapher.Node所以没有update() 没有element元素
// console.log(node.element) // undefined
// node.update()
const nodeSidebar = new sidebar.NodeSidebar(this._host, node);
const nodeSidebar = new sidebar.NodeSidebar(this._host, node, modelNodeName);
nodeSidebar.on('show-documentation', (/* sender, e */) => {
this.showDocumentation(node.type);
});
@ -876,8 +880,9 @@ view.Graph = class extends grapher.Graph {
}
createNode(node) {
const value = new view.Node(this, node);
value.name = (this._nodeKey++).toString();
var node_id = (this._nodeKey++).toString(); // in case input (onnx) node has no name
const value = new view.Node(this, node, node_id);
value.name = node_id;
// value.name = node.name;
this.setNode(value);
return value;
@ -885,14 +890,16 @@ view.Graph = class extends grapher.Graph {
createInput(input) {
const value = new view.Input(this, input);
value.name = (this._nodeKey++).toString();
// value.name = (this._nodeKey++).toString();
value.name = input.name; // input nodes should have name
this.setNode(value);
return value;
}
createOutput(output) {
const value = new view.Output(this, output);
value.name = (this._nodeKey++).toString();
// value.name = (this._nodeKey++).toString();
value.name = output.name; // output nodes should have name
this.setNode(value);
return value;
}
@ -938,13 +945,16 @@ view.Graph = class extends grapher.Graph {
for (const node of graph.nodes) {
// console.log(node)
// My code
if (this._modelNodeName2State.get(node.name) == 'Deleted') {
// console.log(this._modelNodeName2State.get(node.name))
continue;
}
const viewNode = this.createNode(node);
// My code
// if (this._modelNodeName2State.get(viewNode.modelNodeName) == 'Deleted') {
// // console.log(this._modelNodeName2State.get(node.name))
// continue;
// }
const inputs = node.inputs;
for (const input of inputs) {
@ -1025,6 +1035,8 @@ view.Graph = class extends grapher.Graph {
this.createArgument(argument).to(viewOutput);
}
}
// console.log()
}
// My code
@ -1041,9 +1053,13 @@ view.Graph = class extends grapher.Graph {
delete_node_with_children(node_name) {
this.delete_backtrack(node_name);
// console.log("after deleting")
// console.log(this._modelNodeName2State)
}
delete_backtrack(node_name) {
// console.log(this._modelNodeName2ViewNode)
// console.log(node_name) // empty
this._modelNodeName2State.set(node_name, 'Deleted');
this._modelNodeName2ViewNode.get(node_name).element.style.opacity = 0.3;
@ -1057,8 +1073,6 @@ view.Graph = class extends grapher.Graph {
}
build(document, origin) {
for (const argument of this._arguments.values()) {
argument.build();
@ -1071,12 +1085,13 @@ view.Node = class extends grapher.Node {
// 这里的value是一个onnx.Node这里正在构建的是view.Node
// context 是指Graph
constructor(context, value) {
constructor(context, value, node_id) {
super();
this.context = context;
this.value = value;
view.Node.counter = view.Node.counter || 0;
this.id = 'node-' + (value.name ? 'name-' + value.name : 'id-' + (view.Node.counter++).toString());
this.modelNodeName = value.name ? value.name : value.type.name + node_id
this._add(this.value);
}
@ -1109,7 +1124,7 @@ view.Node = class extends grapher.Node {
const content = this.context.view.options.names && (node.name || node.location) ? (node.name || node.location) : type.name.split('.').pop();
const tooltip = this.context.view.options.names && (node.name || node.location) ? type.name : (node.name || node.location);
const title = header.add(null, styles, content, tooltip);
title.on('click', () => this.context.view.showNodeProperties(node, null));
title.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName));
if (node.type.nodes && node.type.nodes.length > 0) {
const definition = header.add(null, styles, '\u0192', 'Show Function Definition');
definition.on('click', () => this.context.view.pushGraph(node.type));
@ -1145,7 +1160,7 @@ view.Node = class extends grapher.Node {
});
if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
const list = this.list();
list.on('click', () => this.context.view.showNodeProperties(node));
list.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName));
for (const initializer of initializers) {
const argument = initializer.arguments[0];
const type = argument.type;
@ -1228,6 +1243,7 @@ view.Input = class extends grapher.Node {
view.Input.counter = view.Input.counter || 0;
const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || '';
this.modelNodeName = value.name
if (name.length > 16) {
name = name.split('/').pop();
}
@ -1258,6 +1274,7 @@ view.Output = class extends grapher.Node {
this.value = value;
const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || '';
this.modelNodeName = value.name
if (name.length > 16) {
name = name.split('/').pop();
}

Loading…
Cancel
Save