From d289f00f96f97f77fc77fdcd18d0701255b03a78 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Tue, 25 Oct 2022 12:18:02 +0800 Subject: [PATCH] support editing initializer for primitive nodes (https://github.com/ZhangGe6/onnx-modifier/issues/6, https://github.com/ZhangGe6/onnx-modifier/issues/9, https://github.com/ZhangGe6/onnx-modifier/issues/22) --- app.py | 2 +- onnx_modifier.py | 83 +++++++++++++++++++++++++---------------- static/index.js | 3 +- static/onnx.js | 12 +++++- static/view-grapher.js | 1 + static/view-sidebar.css | 1 + static/view-sidebar.js | 36 +++++++++++++++--- static/view.js | 74 +++++++++++++++++++++++++----------- utils/__init__.py | 3 +- utils/parse_tools.py | 70 ++++++++++++++++++++++++++++++++++ 10 files changed, 221 insertions(+), 64 deletions(-) create mode 100644 utils/parse_tools.py diff --git a/app.py b/app.py index 18645a1..8410518 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,7 @@ def open_model(): @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() - # print(modify_info) + print(modify_info) onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.modify(modify_info) onnx_modifier.check_and_save_model() diff --git a/onnx_modifier.py b/onnx_modifier.py index fefb67f..034e033 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -9,7 +9,8 @@ import copy import struct import numpy as np import onnx -from utils import make_new_node, make_attr_changed_node +from onnx import numpy_helper +from utils import make_new_node, make_attr_changed_node, parse_tensor class onnxModifier: def __init__(self, model_name, model_proto): @@ -100,7 +101,7 @@ class onnxModifier: # remove node in graph self.graph.node.remove(self.node_name2module[node_name]) - def remove_output_by_name(self, node_name): + def remove_model_output_by_name(self, node_name): self.graph.output.remove(self.node_name2module[node_name]) def remove_node_by_node_states(self, node_states): @@ -112,7 +113,7 @@ class onnxModifier: if node_state == 'Deleted': if node_name in self.graph_output_names: # print('removing output {} ...'.format(node_name)) - self.remove_output_by_name(node_name) + self.remove_model_output_by_name(node_name) else: # print('removing node {} ...'.format(node_name)) self.remove_node_by_name(node_name) @@ -131,24 +132,11 @@ class onnxModifier: for input_name in self.graph_input_names: if not input_name in left_node_inputs: self.graph.input.remove(self.node_name2module[input_name]) - - # remove the left unused Constant nodes - for left_node in self.graph.node: - if left_node.op_type == "Constant": - output_deleted = [False] * len(left_node.output) - for i, output in enumerate(left_node.output): - if not (output in left_node_inputs): - output_deleted[i] = True - - const_node_left_output = [left_node.output[i] for i in range(len(left_node.output)) if not output_deleted[i]] - if len(const_node_left_output) == 0: - self.graph.node.remove(self.node_name2module[left_node.name]) - # self.initializer.remove(self.initializer_name2module[init_name]) - + def modify_node_io_name(self, node_renamed_io): for node_name in node_renamed_io.keys(): if node_name not in self.node_name2module.keys(): - # custom added nodes or custom added model outputs + # custom added nodes or custom added model outputs, or the deleted nodes continue renamed_ios = node_renamed_io[node_name] for src_name, dst_name in renamed_ios.items(): @@ -164,7 +152,14 @@ class onnxModifier: node.input[i] = dst_name for i in range(len(node.output)): if node.output[i] == src_name: - node.output[i] = dst_name + node.output[i] = dst_name + + # TODO: rename the corresponding initializer and update initializer_name2module + if src_name in self.initializer_name2module.keys(): + init = self.initializer_name2module[src_name] + init.name = dst_name + self.initializer_name2module[dst_name] = init + del self.initializer_name2module[src_name] def modify_node_attr(self, node_changed_attr): # we achieve it by deleting the original node and make a (copied) new node @@ -178,7 +173,7 @@ class onnxModifier: # update the node_name2module and initializer_name2module self.gen_name2module_map() - def add_node(self, nodes_info, node_states): + def add_nodes(self, nodes_info, node_states): for node_info in nodes_info.values(): if node_states[node_info['properties']['name']] == "Deleted": continue @@ -199,17 +194,33 @@ class onnxModifier: value_info_protos.append(value_info) self.graph.output.extend(value_info_protos) + def modify_initializer(self, changed_initializer): + for init_name, meta in changed_initializer.items(): + # https://github.com/onnx/onnx/issues/2978 + init_type, init_val_str = meta + # print(init_name, init_type, init_val) + init_val = parse_tensor(init_val_str, init_type) + # print(init_val) + tensor = numpy_helper.from_array(init_val, init_name) + self.initializer_name2module[init_name].CopyFrom(tensor) + def modify(self, modify_info): + ''' + Some functions, such as modify_initializer(), should be placed + before modify_node_io_name(), to avoid name mismatch error. + ''' # print(modify_info['node_states']) # print(modify_info['node_renamed_io']) # print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) # print(modify_info['added_outputs']) + + self.modify_initializer(modify_info['changed_initializer']) self.change_batch_size(modify_info['rebatch_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_attr(modify_info['node_changed_attr']) - self.add_node(modify_info['added_node_info'], modify_info['node_states']) + self.add_nodes(modify_info['added_node_info'], modify_info['node_states']) self.add_outputs(modify_info['added_outputs']) def check_and_save_model(self, save_dir='./modified_onnx'): @@ -218,7 +229,7 @@ class onnxModifier: os.mkdir(save_dir) save_path = os.path.join(save_dir, 'modified_' + self.model_name) - # adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted + # adding new node like self.add_nodes() and self.modify_node_attr() can not guarantee the nodes are topologically sorted # so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked # I turn off the onnx checker as a workaround. # onnx.checker.check_model(self.model_proto) @@ -226,7 +237,10 @@ class onnxModifier: print("model saved in {} !".format(save_dir)) def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None): - import onnxruntime as rt + import onnxruntime as rt + model_proto_bytes = onnx._serialize(self.model_proto) + inference_session = rt.InferenceSession(model_proto_bytes) + if not x: x = np.random.randn(*input_shape).astype(np.float32) if not output_names: @@ -234,20 +248,16 @@ class onnxModifier: # output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[]) output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[]) self.graph.output.append(output_value_info) - - model_proto_bytes = onnx._serialize(self.model_proto) - inference_session = rt.InferenceSession(model_proto_bytes) + output_names = [inference_session.get_outputs()[0].name] input_name = inference_session.get_inputs()[0].name - output_name = inference_session.get_outputs()[0].name - - # This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506 - out = inference_session.run(None, {input_name: x})[0] + out = inference_session.run(output_names, {input_name: x})[0] print(out.shape) if __name__ == "__main__": - model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): @@ -315,7 +325,7 @@ if __name__ == "__main__": def test_add_node(): node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}} - onnx_modifier.add_node(node_info) + onnx_modifier.add_nodes(node_info) onnx_modifier.inference() onnx_modifier.check_and_save_model() @@ -357,4 +367,11 @@ if __name__ == "__main__": onnx_modifier.check_and_save_model() # test_change_batch_size() - \ No newline at end of file + + def test_modify_initializer(): + onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) + onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']}) + onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) + test_modify_initializer() + + \ No newline at end of file diff --git a/static/index.js b/static/index.js index f5e3bb9..ad48f67 100644 --- a/static/index.js +++ b/static/index.js @@ -232,7 +232,8 @@ host.BrowserHost = class { 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)), 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, this._view._graph._renameMap, this._view._graph._modelNodeName2State)), - 'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo) + 'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo), + 'changed_initializer' : this.mapToObjectRec(this._view._graph._initializerEditInfo) }) }).then(function (response) { return response.text(); diff --git a/static/onnx.js b/static/onnx.js index a3e0c60..5e6f23c 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -444,6 +444,7 @@ onnx.Graph = class { // console.log(graph) for (const initializer of graph.initializer) { const tensor = context.tensor(initializer.name); + // console.log(initializer) // type: TensorProto tensor.initializer = new onnx.Tensor(context, initializer, 'Initializer'); } for (const sparse_initializer of graph.sparse_initializer) { @@ -566,6 +567,9 @@ onnx.Graph = class { arg_list = [this._context.argument(arg_name)] } + for (var arg of arg_list) { + arg.is_custom_added = true; + } inputs.push(new onnx.Parameter(input.name, arg_list)); } @@ -596,6 +600,10 @@ onnx.Graph = class { arg_list = [this._context.argument(arg_name)] } + + for (var arg of arg_list) { + arg.is_custom_added = true; + } outputs.push(new onnx.Parameter(output.name, arg_list)); } @@ -677,7 +685,8 @@ onnx.Argument = class { this._annotation = annotation; this._description = description || ''; - this.original_name = original_name || name + this.original_name = original_name || name; + this.is_custom_added = false; } @@ -1766,6 +1775,7 @@ onnx.GraphContext = class { argument(name, original_name) { const tensor = this.tensor(name); + // console.log(tensor) const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; return new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name); diff --git a/static/view-grapher.js b/static/view-grapher.js index 43a19ae..d53b191 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -27,6 +27,7 @@ grapher.Graph = class { this._addedOutputs = []; this._reBatchInfo = new Map(); + this._initializerEditInfo = new Map(); } get options() { diff --git a/static/view-sidebar.css b/static/view-sidebar.css index bc255ea..e5b9ed9 100644 --- a/static/view-sidebar.css +++ b/static/view-sidebar.css @@ -24,6 +24,7 @@ .sidebar-view-item-value-line-link { padding: 4px 6px 4px 6px; cursor: default; } .sidebar-view-item-value-line-link:hover { text-decoration: underline; } .sidebar-view-item-value-line-border { padding: 4px 6px 4px 6px; border-top: 1px solid rgba(27, 31, 35, 0.05); } +.sidebar-view-item-value-border { padding: 4px 6px 4px 6px;} .sidebar-view-item-value-line-content { white-space: pre; word-wrap: normal; overflow: auto; display: block; } .sidebar-view-item-value-expander { font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; float: right; color: #aaa; cursor: pointer; user-select: none; -webkit-user-select: none; -moz-user-select: none; padding: 4px 6px 4px 6px; } .sidebar-view-item-value-expander:hover { color: #000; } diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 85ce65e..07bbe30 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -132,6 +132,7 @@ sidebar.NodeSidebar = class { this._attributes = []; this._inputs = []; this._outputs = []; + // console.log(node) // onnx.Node if (node.type) { let showDocumentation = null; @@ -305,7 +306,7 @@ sidebar.NodeSidebar = class { } _addInput(name, input, param_idx) { - // console.log(input) + // console.log(input) // type: onnx.Parameter if (input.arguments.length > 0) { const view = new sidebar.ParameterView(this._host, input, 'input', param_idx, this._modelNodeName); view.on('export-tensor', (sender, tensor) => { @@ -875,7 +876,9 @@ sidebar.ArgumentView = class { const quantization = argument.quantization; const type = argument.type; const location = this._argument.location !== undefined; - if (type || initializer || quantization || location) { + const is_custom_added = argument.is_custom_added; + // console.log(argument) + if (type || initializer || quantization || location || is_custom_added) { this._expander = this._host.document.createElement('div'); this._expander.className = 'sidebar-view-item-value-expander'; this._expander.innerText = '+'; @@ -949,6 +952,7 @@ sidebar.ArgumentView = class { this._expander.innerText = '-'; const initializer = this._argument.initializer; + // console.log(this._argument, initializer) // type: onnx.Argument, onnx.Tensor if (this._hasId && this._hasKind) { const kindLine = this._host.document.createElement('div'); kindLine.className = 'sidebar-view-item-value-line-border'; @@ -998,8 +1002,29 @@ sidebar.ArgumentView = class { location.innerHTML = 'location: ' + '' + this._argument.location + ''; this._element.appendChild(location); } + + if (initializer || this._argument.is_custom_added) { + const editInitializer = this._host.document.createElement('div'); + editInitializer.className = 'sidebar-view-item-value-line-border'; + editInitializer.innerHTML = 'If this is an initializer, you can input new value for it here:'; + this._element.appendChild(editInitializer); + + var inputInitializer = document.createElement("INPUT"); + inputInitializer.setAttribute("type", "text"); + inputInitializer.setAttribute("size", "42"); + inputInitializer.addEventListener('input', (e) => { + // console.log(e.target.value) + this._host._view._graph.changeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, this._argument.type._dataType, e.target.value); + }); + this._element.appendChild(inputInitializer); + } if (initializer) { + // to edit the existed initializer + const origInitLine = this._host.document.createElement('div'); + origInitLine.className = 'sidebar-view-item-value-line-border'; + origInitLine.innerHTML = 'original initializer value:'; + this._element.appendChild(origInitLine); const contentLine = this._host.document.createElement('pre'); const valueLine = this._host.document.createElement('div'); try { @@ -1016,10 +1041,11 @@ sidebar.ArgumentView = class { this._element.appendChild(this._saveButton); } - valueLine.className = 'sidebar-view-item-value-line-border'; + // valueLine.className = 'sidebar-view-item-value-line-border'; + valueLine.className = 'sidebar-view-item-value-border' contentLine.innerHTML = state || initializer.toString(); - console.log(initializer) - console.log(state, initializer.toString()) + // console.log(initializer) + // console.log(state, initializer.toString()) } catch (err) { contentLine.innerHTML = err.toString(); diff --git a/static/view.js b/static/view.js index 6551369..9d356bc 100644 --- a/static/view.js +++ b/static/view.js @@ -1121,6 +1121,7 @@ view.Graph = class extends grapher.Graph { } add(graph) { + // console.log(graph) // type: onnx.Graph const clusters = new Set(); const clusterParentMap = new Map(); const groups = graph.groups; @@ -1145,6 +1146,7 @@ view.Graph = class extends grapher.Graph { } for (var node of graph.nodes) { + // console.log(node) // type: onnx.Node var viewNode = this.createNode(node); var inputs = node.inputs; @@ -1377,8 +1379,31 @@ view.Graph = class extends grapher.Graph { this.view._updateGraph() } + getOriginalName(param_type, modelNodeName, param_index, arg_index) { + if (param_type == 'model_input') { + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name + } + + if (param_type == 'model_output') { + modelNodeName = 'out_' + modelNodeName + // console.log(modelNodeName) + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name + // console.log(orig_arg_name) + } + + if (param_type == 'input') { + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name + // console.log(orig_arg_name) + } + if (param_type == 'output') { + var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name + // console.log(orig_arg_name) + } - changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue, orig_arg_name) { + return orig_arg_name + } + + changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue) { // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { @@ -1393,26 +1418,8 @@ view.Graph = class extends grapher.Graph { // console.log(this._addedNode) else { // for the nodes in the original model - - if (param_type == 'model_input') { - var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name - } - - if (param_type == 'model_output') { - modelNodeName = 'out_' + modelNodeName - // console.log(modelNodeName) - var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name - // console.log(orig_arg_name) - } - - if (param_type == 'input') { - var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name - // console.log(orig_arg_name) - } - if (param_type == 'output') { - var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name - // console.log(orig_arg_name) - } + var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) + console.log(orig_arg_name) if (!this._renameMap.get(modelNodeName)) { this._renameMap.set(modelNodeName, new Map()); @@ -1422,9 +1429,32 @@ view.Graph = class extends grapher.Graph { } this.view._updateGraph() - } + changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { + // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { + // if (this._addedNode.has(modelNodeName)) { // for custom added node + // if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { + // this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue + // } + + // if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { + // this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue + // } + // // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph + // } + // // console.log(this._addedNode) + + // else { // for the nodes in the original model + + var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) + this._initializerEditInfo.set(orig_arg_name, [type, targetValue]); + // console.log(this._renameMap) + // } + console.log(this._initializerEditInfo) + + // this.view._updateGraph() + } build(document, origin) { diff --git a/utils/__init__.py b/utils/__init__.py index 2cdf9bf..c26dd06 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +1,2 @@ -from .make_nodes import * \ No newline at end of file +from .make_nodes import * +from .parse_tools import * \ No newline at end of file diff --git a/utils/parse_tools.py b/utils/parse_tools.py new file mode 100644 index 0000000..ead0e2d --- /dev/null +++ b/utils/parse_tools.py @@ -0,0 +1,70 @@ +import numpy as np + +def parse_value(value_str, value_type): + if value_type.startswith('int'): + return int(value_str) + elif value_type.startswith('float'): + return float(value_str) + else: + raise RuntimeError("type {} is not considered in current version. \ + You can kindly report an issue for this problem. Thanks!".format(value_type)) + +def parse_tensor(tensor_str, tensor_type): + def extract_val(): + num_str = "" + while (len(stk) > 0) and (type(stk[-1]) == str and ord('0') <= ord(stk[-1]) <= ord('9') or stk[-1] in ['+', '-', '.', 'e', 'E']): + num_str = stk.pop() + num_str + + if len(num_str) > 0: + return parse_value(num_str, tensor_type) + else: + return None + + tensor_str = tensor_str.replace(" ", "") + stk = [] + for i, c in enumerate(tensor_str): # '[' ',' ']' '.' '-' or value + if c == ",": + ext_val = extract_val() + if ext_val is not None: stk.append(ext_val) + elif c == "]": + ext_val = extract_val() + if ext_val is not None: stk.append(ext_val) + + arr = [] + while stk[-1] != '[': + arr.append(stk.pop()) + stk.pop() # the left [ + + arr.reverse() + stk.append(arr) + else: + stk.append(c) + val = stk[0] + + # wrap with numpy with the specific data type + if tensor_type == "int64": + return np.array(val, dtype=np.int64) + elif tensor_type == "int32": + return np.array(val, dtype=np.int32) + elif tensor_type == "int8": + return np.array(val, dtype=np.int8) + elif tensor_type == "float64": + return np.array(val, dtype=np.float64) + elif tensor_type == "float32": + return np.array(val, dtype=np.float32) + else: + raise RuntimeError("type {} is not considered in current version. \ + You can kindly report an issue for this problem. Thanks!".format(tensor_type)) + +if __name__ == "__main__": + # tensor_str = "1" + # tensor_str = "[1, 2, 3]" + tensor_str = "[[10, 2.3, 3],[1, 2e6, 3]]" + val = parse_tensor(tensor_str, "float32") + print(type(val), val) + + tensor_str = "[[10, 2, 3],[1, 2, 3]]" + val = parse_tensor(tensor_str, "int64") + print(type(val), val) + + \ No newline at end of file