diff --git a/onnx_modifier.py b/onnx_modifier.py index d55b93c..1341050 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -1,8 +1,8 @@ # https://leimao.github.io/blog/ONNX-Python-API/ # https://github.com/saurabh-shandilya/onnx-utils +# import io import os -from platform import node import onnx class onnxModifier: @@ -16,20 +16,19 @@ class onnxModifier: def gen_node_name2module_map(self): self.node_name2module = dict() node_idx = 0 - for node in self.graph.input: - node_idx += 1 - self.node_name2module[node.name] = node + # for node in self.graph.input: + # node_idx += 1 + # self.node_name2module[node.name] = node for node in self.graph.node: if node.name == '': node.name = str(node.op_type) + str(node_idx) - node_idx += 1 - self.node_name2module[node.name] = node - - for node in self.graph.output: node_idx += 1 self.node_name2module[node.name] = node - self.graph_output_names = [node.name for node in self.graph.output] + + for out in self.graph.output: + self.node_name2module[out.name] = out + self.graph_output_names = [out.name for out in self.graph.output] # print(self.node_name2module.keys()) @classmethod @@ -41,7 +40,6 @@ class onnxModifier: @classmethod def from_name_stream(cls, name, stream): # https://leimao.github.io/blog/ONNX-IO-Stream/ - stream.seek(0) model_proto = onnx.load_model(stream, onnx.ModelProto) return cls(name, model_proto) @@ -56,8 +54,10 @@ class onnxModifier: for node_name, node_state in node_states.items(): if node_state == 'Deleted': if node_name in self.graph_output_names: + # print('removing output {} ...'.format(node_name)) self.remove_output_by_name(node_name) else: + # print('removing node {} ...'.format(node_name)) self.remove_node_by_name(node_name) def check_and_save_model(self, save_dir='./res_onnx'): @@ -72,13 +72,23 @@ class onnxModifier: if __name__ == "__main__": - model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) - onnx_modifier.remove_node_by_name('Softmax_nc_rename_64') - onnx_modifier.remove_output_by_name('softmaxout_1') + # for node in onnx_modifier.graph.node: + # print(node.name) + # for node in onnx_modifier.graph.output: + # print(node.name) + print(onnx_modifier.node_name2module.keys()) + print(onnx_modifier.graph_output_names) + + # onnx_modifier.remove_node_by_name('Softmax_nc_rename_64') + # onnx_modifier.remove_output_by_name('softmaxout_1') # onnx_modifier.graph.output.remove(onnx_modifier.node_name2module['softmaxout_1']) - onnx_modifier.check_and_save_model() + # onnx_modifier.check_and_save_model() + # print(type(onnx_modifier.graph.input)) + # print(type(onnx_modifier.graph.output)) # print(onnx_modifier.graph.input) # print(onnx_modifier.graph.output) # print(onnx_modifier.node_name2module['Softmax_nc_rename_64']) diff --git a/static/view-grapher.js b/static/view-grapher.js index e216b12..10a6fb4 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -24,7 +24,7 @@ grapher.Graph = class { } setNode(node) { - const key = node.name; + const key = node.name; // node id const value = this._nodes.get(key); if (value) { value.label = node; @@ -39,11 +39,18 @@ grapher.Graph = class { } // My code - const modelNodeName = node.value.name; + // var modelNodeName = node.value.name; + // if (modelNodeName == '') { // in case that model node has no name + // modelNodeName = node.value.type.name + node.name + // // console.log(node.value) + // // console.log(modelNodeName) + // } + const modelNodeName = node.modelNodeName this._modelNodeName2ViewNode.set(modelNodeName, node); this._modelNodeName2State.set(modelNodeName, 'Exist'); // console.log(modelNodeName) + // console.log(node.modelNodeName) } setEdge(edge) { @@ -60,8 +67,10 @@ grapher.Graph = class { // My code // _namedEdges: from : to - var from_node_name = edge.from.value.name - var to_node_name = edge.to.value.name + // var from_node_name = edge.from.value.name + // var to_node_name = edge.to.value.name + var from_node_name = edge.from.modelNodeName + var to_node_name = edge.to.modelNodeName if (!this._namedEdges.has(from_node_name)) { this._namedEdges.set(from_node_name, []); } diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 84cfe06..e55b795 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -124,9 +124,10 @@ sidebar.Sidebar = class { sidebar.NodeSidebar = class { - constructor(host, node) { + constructor(host, node, modelNodeName) { this._host = host; this._node = node; + this._modelNodeName = modelNodeName; this._elements = []; this._attributes = []; this._inputs = []; @@ -292,18 +293,18 @@ sidebar.NodeSidebar = class { // console.log(title) if (title === 'Delete') { buttonElement.addEventListener('click', () => { - this._host._view._graph.delete_node(this._node.name) + this._host._view._graph.delete_node(this._modelNodeName) }); } if (title === 'DeleteWithChildren') { buttonElement.addEventListener('click', () => { - this._host._view._graph.delete_node_with_children(this._node.name) + this._host._view._graph.delete_node_with_children(this._modelNodeName) }); } if (title === 'Recover') { // console.log('pressed') buttonElement.addEventListener('click', () => { - this._host._view._graph.recover_node(this._node.name) + this._host._view._graph.recover_node(this._modelNodeName) }); } diff --git a/static/view.js b/static/view.js index dc62c19..c3125bc 100644 --- a/static/view.js +++ b/static/view.js @@ -482,10 +482,14 @@ view.View = class { const graph = this.activeGraph; this.lastViewGraph = this._graph; - // if (lastViewGraph) { - // console.log(this._graph) - // this.lastModelNodeName2State = lastViewGraph._modelNodeName2State; - // } + // console.log("this.lastViewGraph") + // console.log(this.lastViewGraph) + if (this.lastViewGraph) { + // console.log(this._graph) + // this.lastModelNodeName2State = lastViewGraph._modelNodeName2State; + console.log('lastViewGraph _modelNodeName2State') + console.log(this.lastViewGraph._modelNodeName2State) + } // console.log("_updateGraph is called"); return this._timeout(100).then(() => { @@ -582,7 +586,7 @@ view.View = class { const viewGraph = new view.Graph(this, model, groups, options); // console.log(viewGraph) if (this.lastViewGraph) { - // console.log(this.lastViewGraph._modelNodeName2ViewNode) + // console.log(this.lastViewGraph._modelNodeName2State) // console.log('node state of lastViewGraph is loaded') viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; } @@ -792,13 +796,13 @@ view.View = class { } } - showNodeProperties(node, input) { + showNodeProperties(node, input, modelNodeName) { if (node) { try { // console.log(node) // 注意是onnx.Node, 不是grapher.Node,所以没有update(), 没有element元素 // console.log(node.element) // undefined // node.update() - const nodeSidebar = new sidebar.NodeSidebar(this._host, node); + const nodeSidebar = new sidebar.NodeSidebar(this._host, node, modelNodeName); nodeSidebar.on('show-documentation', (/* sender, e */) => { this.showDocumentation(node.type); }); @@ -876,8 +880,9 @@ view.Graph = class extends grapher.Graph { } createNode(node) { - const value = new view.Node(this, node); - value.name = (this._nodeKey++).toString(); + var node_id = (this._nodeKey++).toString(); // in case input (onnx) node has no name + const value = new view.Node(this, node, node_id); + value.name = node_id; // value.name = node.name; this.setNode(value); return value; @@ -885,14 +890,16 @@ view.Graph = class extends grapher.Graph { createInput(input) { const value = new view.Input(this, input); - value.name = (this._nodeKey++).toString(); + // value.name = (this._nodeKey++).toString(); + value.name = input.name; // input nodes should have name this.setNode(value); return value; } createOutput(output) { const value = new view.Output(this, output); - value.name = (this._nodeKey++).toString(); + // value.name = (this._nodeKey++).toString(); + value.name = output.name; // output nodes should have name this.setNode(value); return value; } @@ -938,13 +945,16 @@ view.Graph = class extends grapher.Graph { for (const node of graph.nodes) { // console.log(node) - // My code if (this._modelNodeName2State.get(node.name) == 'Deleted') { // console.log(this._modelNodeName2State.get(node.name)) continue; } - const viewNode = this.createNode(node); + // My code + // if (this._modelNodeName2State.get(viewNode.modelNodeName) == 'Deleted') { + // // console.log(this._modelNodeName2State.get(node.name)) + // continue; + // } const inputs = node.inputs; for (const input of inputs) { @@ -1025,6 +1035,8 @@ view.Graph = class extends grapher.Graph { this.createArgument(argument).to(viewOutput); } } + + // console.log() } // My code @@ -1041,9 +1053,13 @@ view.Graph = class extends grapher.Graph { delete_node_with_children(node_name) { this.delete_backtrack(node_name); + // console.log("after deleting") + // console.log(this._modelNodeName2State) } delete_backtrack(node_name) { + // console.log(this._modelNodeName2ViewNode) + // console.log(node_name) // empty this._modelNodeName2State.set(node_name, 'Deleted'); this._modelNodeName2ViewNode.get(node_name).element.style.opacity = 0.3; @@ -1057,8 +1073,6 @@ view.Graph = class extends grapher.Graph { } - - build(document, origin) { for (const argument of this._arguments.values()) { argument.build(); @@ -1071,12 +1085,13 @@ view.Node = class extends grapher.Node { // 这里的value是一个onnx.Node,这里正在构建的是view.Node // context 是指Graph - constructor(context, value) { + constructor(context, value, node_id) { super(); this.context = context; this.value = value; view.Node.counter = view.Node.counter || 0; this.id = 'node-' + (value.name ? 'name-' + value.name : 'id-' + (view.Node.counter++).toString()); + this.modelNodeName = value.name ? value.name : value.type.name + node_id this._add(this.value); } @@ -1109,7 +1124,7 @@ view.Node = class extends grapher.Node { const content = this.context.view.options.names && (node.name || node.location) ? (node.name || node.location) : type.name.split('.').pop(); const tooltip = this.context.view.options.names && (node.name || node.location) ? type.name : (node.name || node.location); const title = header.add(null, styles, content, tooltip); - title.on('click', () => this.context.view.showNodeProperties(node, null)); + title.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName)); if (node.type.nodes && node.type.nodes.length > 0) { const definition = header.add(null, styles, '\u0192', 'Show Function Definition'); definition.on('click', () => this.context.view.pushGraph(node.type)); @@ -1145,7 +1160,7 @@ view.Node = class extends grapher.Node { }); if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) { const list = this.list(); - list.on('click', () => this.context.view.showNodeProperties(node)); + list.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName)); for (const initializer of initializers) { const argument = initializer.arguments[0]; const type = argument.type; @@ -1228,6 +1243,7 @@ view.Input = class extends grapher.Node { view.Input.counter = view.Input.counter || 0; const types = value.arguments.map((argument) => argument.type || '').join('\n'); let name = value.name || ''; + this.modelNodeName = value.name if (name.length > 16) { name = name.split('/').pop(); } @@ -1258,6 +1274,7 @@ view.Output = class extends grapher.Node { this.value = value; const types = value.arguments.map((argument) => argument.type || '').join('\n'); let name = value.name || ''; + this.modelNodeName = value.name if (name.length > 16) { name = name.split('/').pop(); }