From 95f18e37b0cb04e73b5060af6fcd5850b8af27f9 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Sun, 12 Jun 2022 15:50:46 +0800 Subject: [PATCH] the feature of `change node attribute` is basically done --- onnx_modifier.py | 47 ++++++++++++++++++++++++++++++-------- static/index.js | 7 +++--- static/view-grapher.js | 1 + static/view-sidebar.js | 3 ++- static/view.js | 39 +++++++++++++++++++++---------- utils/make_nodes.py | 52 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 122 insertions(+), 27 deletions(-) diff --git a/onnx_modifier.py b/onnx_modifier.py index c480fba..6b310cf 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -8,7 +8,7 @@ import copy import numpy as np import onnx import onnxruntime as rt -from utils import make_node +from utils import make_new_node, make_attr_changed_node class onnxModifier: def __init__(self, model_name, model_proto): @@ -115,14 +115,25 @@ class onnxModifier: node.input[i] = dst_name for i in range(len(node.output)): if node.output[i] == src_name: - node.output[i] = dst_name - + node.output[i] = dst_name + def modify_node_attr(self, node_changed_attr): + # we achieve it by deleting the original node and make a (copied) new node + # print(node_changed_attr) + for node_name in node_changed_attr.keys(): + orig_node = self.node_name2module[node_name] + attr_changed_node = make_attr_changed_node(orig_node, node_changed_attr[node_name]) + self.graph.node.remove(self.node_name2module[node_name]) + self.graph.node.append(attr_changed_node) + + # update the node_name2module and initializer_name2module + self.gen_name2module_map() + def add_node(self, nodes_info, node_states): for node_info in nodes_info.values(): if node_states[node_info['properties']['name']] == "Deleted": continue - node = make_node(node_info) + node = make_new_node(node_info) # print(node) self.graph.node.append(node) @@ -131,17 +142,23 @@ class onnxModifier: def modify(self, modify_info): # print(modify_info['node_states']) # print(modify_info['node_renamed_io']) + print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) self.remove_node_by_node_states(modify_info['node_states']) - self.modify_node_io_name(modify_info['node_renamed_io']) - self.add_node(modify_info['added_node_info'], modify_info['node_states']) + self.modify_node_io_name(modify_info['node_renamed_io']) + self.modify_node_attr(modify_info['node_changed_attr']) + self.add_node(modify_info['added_node_info'], modify_info['node_states']) + def check_and_save_model(self, save_dir='./modified_onnx'): if not os.path.exists(save_dir): os.mkdir(save_dir) - save_path = os.path.join(save_dir, 'modified_' + self.model_name) - onnx.checker.check_model(self.model_proto) + + # adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted + # so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked + # I turn off the onnx checker as a workaround. + # onnx.checker.check_model(self.model_proto) onnx.save(self.model_proto, save_path) def inference(self, x=None, output_names=None): @@ -241,7 +258,7 @@ if __name__ == "__main__": # print(inp.name) onnx_modifier.check_and_save_model() - remove_node_by_node_states() + # remove_node_by_node_states() def test_modify_node_io_name(): @@ -257,7 +274,17 @@ if __name__ == "__main__": onnx_modifier.inference() onnx_modifier.check_and_save_model() - # test_add_node() + + def test_change_node_attr(): + # changed_attr = {'Clip_3': {'max': 5}} + changed_attr = {'Conv_2': {'group': 64}} + + onnx_modifier.modify_node_attr(changed_attr) + + onnx_modifier.check_and_save_model() + + test_change_node_attr() + \ No newline at end of file diff --git a/static/index.js b/static/index.js index 0aff9d8..fb5ea9a 100644 --- a/static/index.js +++ b/static/index.js @@ -215,9 +215,9 @@ host.BrowserHost = class { const downloadButton = this.document.getElementById('download-graph'); downloadButton.addEventListener('click', () => { - console.log(this._view._graph._addedNode) - console.log(this._view._graph._renameMap) - // https://healeycodes.com/talking-between-languages + // console.log(this._view._graph._addedNode) + // console.log(this._view._graph._renameMap) + // // https://healeycodes.com/talking-between-languages fetch('/download', { // Declare what type of data we're sending headers: { @@ -228,6 +228,7 @@ host.BrowserHost = class { body: JSON.stringify({ 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), + 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)) } diff --git a/static/view-grapher.js b/static/view-grapher.js index 39eaf58..5e8f253 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -21,6 +21,7 @@ grapher.Graph = class { this._pathArgumentNames = new Set(); // the name of arguments which occurs in both sides of an edge this._renameMap = new Map(); + this._changedAttributes = new Map(); this._addedNode = new Map(); } diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 0d13b5d..8f3ba11 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -658,7 +658,7 @@ class NodeAttributeView { this._element.appendChild(this._expander); } const value = this._attribute.value; - // console.log(this._attribute.name, value, type) + console.log(this._attribute.name, value, type) switch (type) { case 'graph': { const line = this._host.document.createElement('div'); @@ -699,6 +699,7 @@ class NodeAttributeView { attr_input.setAttribute("value", content ? content : 'undefined'); attr_input.addEventListener('input', (e) => { // console.log(e.target.value); + // console.log(this.parse_value(e.target.value, type)) this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, this.parse_value(e.target.value, type)); // console.log(this._host._view._graph._renameMap); }); diff --git a/static/view.js b/static/view.js index 6427d61..f963356 100644 --- a/static/view.js +++ b/static/view.js @@ -462,6 +462,7 @@ view.View = class { if (active_graph && this.lastViewGraph) { this.refreshAddedNode() this.refreshNodeArguments() + } return active_graph @@ -582,6 +583,7 @@ view.View = class { // console.log('node state of lastViewGraph is loaded') viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; viewGraph._renameMap = this.lastViewGraph._renameMap; + viewGraph._changedAttributes = this.lastViewGraph._changedAttributes; viewGraph._addedNode = this.lastViewGraph._addedNode; viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey // console.log(viewGraph._renameMap); @@ -904,16 +906,7 @@ view.View = class { refreshNodeArguments() { // console.log(this.lastViewGraph._renameMap) // console.log(this._graphs[0]) - // console.log(this._graphs[0]._nodes) - - // for (const node_name of this.lastViewGraph._renameMap.keys()) { - // var rename_map = this.lastViewGraph._renameMap.get(node_name) - // var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) - // console.log(node) - // console.log(rename_map) - // } - - + // console.log(this._graphs[0]._nodes) for (var node of this._graphs[0]._nodes) { // this node has some changed arguments @@ -978,8 +971,20 @@ view.View = class { // console.log(this._graphs[0]._context._arguments) + for (const node_name of this.lastViewGraph._changedAttributes.keys()) { + var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name) + var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) + // console.log(node) + // console.log(attr_change_map) + // for (const attr of node._attributes) { + for (var i = 0; i < node._attributes.length; ++i) { + if (attr_change_map.get(node._attributes[i].name)) { + node._attributes[i]._value = attr_change_map.get(node._attributes[i].name) + } + } + } } @@ -1330,6 +1335,16 @@ view.Graph = class extends grapher.Graph { this._addedNode.get(modelNodeName).attributes.set(attributeName, targetValue) } // console.log(this._addedNode) + + else { // for the nodes in the original model + if (!this._changedAttributes.get(modelNodeName)) { + this._changedAttributes.set(modelNodeName, new Map()); + } + this._changedAttributes.get(modelNodeName).set(attributeName, targetValue) + // console.log(this._changedAttributes) + + } + this.view._updateGraph() } @@ -1348,7 +1363,7 @@ view.Graph = class extends grapher.Graph { } // console.log(this._addedNode) - // else { // for the nodes in the original model + // else { // if (!this._renameMap.get(modelNodeName)) { // this._renameMap.set(modelNodeName, new Map()); // } @@ -1356,7 +1371,7 @@ view.Graph = class extends grapher.Graph { // // console.log(this._renameMap) // } - else { + else { // for the nodes in the original model if (!this._renameMap.get(modelNodeName)) { this._renameMap.set(modelNodeName, new Map()); } diff --git a/utils/make_nodes.py b/utils/make_nodes.py index e0b0a4f..6aa3ab1 100644 --- a/utils/make_nodes.py +++ b/utils/make_nodes.py @@ -1,6 +1,8 @@ import onnx +from onnx import AttributeProto -def make_node(node_info): + +def make_new_node(node_info): name = node_info['properties']['name'] op_type = node_info['properties']['op_type'] # attributes = node_info['attributes'] @@ -31,4 +33,52 @@ def make_node(node_info): # print(node) + return node + +def make_attr_changed_node(node, attr_change_info): + # convert the changed attribute value into the type that is consistent with the original attribute + # because AttributeProto is constructed barely based on the input value + # https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L472 + def make_type_value(value, AttributeProto_type): + # https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/enum_type_wrapper.py#L60 + attr_type = AttributeProto.AttributeType.Name(AttributeProto_type) + if attr_type == "FLOAT": + return float(value) + elif attr_type == "INT": + return int(value) + elif attr_type == "STRING": + return str(value) + elif attr_type == "FLOATS": + return [float(v) for v in value] + elif attr_type == "INTS": + return [int(v) for v in value] + elif attr_type == "STRINGS": + return [str(v) for v in value] + else: + raise RuntimeError("type {} is not considered in current version. \ + You can kindly report an issue for this problem. Thanks!".format(attr_type)) + + new_attr = dict() + for attr in node.attribute: + # print(onnx.helper.get_attribute_value(attr)) + if attr.name in attr_change_info.keys(): + new_attr[attr.name] = make_type_value(attr_change_info[attr.name], attr.type) + else: + # https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L548 + new_attr[attr.name] = onnx.helper.get_attribute_value(attr) + # print(new_attr) + # print(node.input, node.output) + + # print(node) + + node = onnx.helper.make_node( + op_type=node.op_type, + inputs=node.input, + outputs=node.output, + name=node.name, + **new_attr + ) + + # print(node) + return node \ No newline at end of file