diff --git a/onnx_modifier.py b/onnx_modifier.py index 034e033..e4bf5f3 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -10,7 +10,8 @@ import struct import numpy as np import onnx from onnx import numpy_helper -from utils import make_new_node, make_attr_changed_node, parse_tensor +from utils import make_new_node, make_attr_changed_node +from utils import parse_tensor, np2onnxdtype class onnxModifier: def __init__(self, model_name, model_proto): @@ -154,7 +155,7 @@ class onnxModifier: if node.output[i] == src_name: node.output[i] = dst_name - # TODO: rename the corresponding initializer and update initializer_name2module + # rename the corresponding initializer and update initializer_name2module if src_name in self.initializer_name2module.keys(): init = self.initializer_name2module[src_name] init.name = dst_name @@ -170,8 +171,9 @@ class onnxModifier: self.graph.node.remove(self.node_name2module[node_name]) self.graph.node.append(attr_changed_node) - # update the node_name2module and initializer_name2module - self.gen_name2module_map() + # update the node_name2module + del self.node_name2module[node_name] + self.node_name2module[node_name] = attr_changed_node def add_nodes(self, nodes_info, node_states): for node_info in nodes_info.values(): @@ -195,19 +197,36 @@ class onnxModifier: self.graph.output.extend(value_info_protos) def modify_initializer(self, changed_initializer): + # print(changed_initializer) for init_name, meta in changed_initializer.items(): # https://github.com/onnx/onnx/issues/2978 init_type, init_val_str = meta + if init_val_str == "": continue # in case we clear the input # print(init_name, init_type, init_val) init_val = parse_tensor(init_val_str, init_type) # print(init_val) - tensor = numpy_helper.from_array(init_val, init_name) - self.initializer_name2module[init_name].CopyFrom(tensor) + # for primary initilizers + if init_name in self.initializer_name2module.keys(): + tensor = numpy_helper.from_array(init_val, init_name) + self.initializer_name2module[init_name].CopyFrom(tensor) + # for custom added initilizers + else: + initializer_tensor = onnx.helper.make_tensor( + name=init_name, + data_type=np2onnxdtype(init_val.dtype), + dims=init_val.shape, + vals=init_val) + # print(initializer_tensor) + self.initializer.append(initializer_tensor) + self.initializer_name2module[init_name] = initializer_tensor def modify(self, modify_info): ''' - Some functions, such as modify_initializer(), should be placed + 1. Some functions, such as modify_initializer(), should be placed before modify_node_io_name(), to avoid name mismatch error. + 2. add_nodes() should be placed at the first place, otherwise + remove_node_by_node_states() will delete the initializer of + newly added nodes mistakenly ''' # print(modify_info['node_states']) # print(modify_info['node_renamed_io']) @@ -215,14 +234,14 @@ class onnxModifier: # print(modify_info['added_node_info']) # print(modify_info['added_outputs']) + self.add_nodes(modify_info['added_node_info'], modify_info['node_states']) self.modify_initializer(modify_info['changed_initializer']) self.change_batch_size(modify_info['rebatch_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_attr(modify_info['node_changed_attr']) - self.add_nodes(modify_info['added_node_info'], modify_info['node_states']) self.add_outputs(modify_info['added_outputs']) - + def check_and_save_model(self, save_dir='./modified_onnx'): print("saving model...") if not os.path.exists(save_dir): @@ -257,7 +276,7 @@ class onnxModifier: if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): @@ -368,10 +387,21 @@ if __name__ == "__main__": onnx_modifier.check_and_save_model() # test_change_batch_size() - def test_modify_initializer(): + def test_modify_primary_initializer(): onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']}) onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) - test_modify_initializer() - + # test_modify_primary_initializer() + + def test_modify_new_initializer(): + modify_info = {'node_states': {'input': 'Exist', 'Conv_0': 'Exist', 'LeakyRelu_1': 'Exist', 'Conv_2': 'Exist', 'LeakyRelu_3': 'Exist', 'Conv_4': 'Exist', 'LeakyRelu_5': 'Exist', 'Conv_6': 'Exist', 'LeakyRelu_7': 'Exist', 'Conv_8': 'Exist', 'LeakyRelu_9': 'Exist', 'Conv_10': 'Exist', 'Conv_11': 'Exist', 'LeakyRelu_12': 'Exist', 'Conv_13': 'Exist', 'Conv_14': 'Exist', 'LeakyRelu_15': 'Exist', 'Conv_16': 'Exist', 'Concat_17': 'Exist', 'LeakyRelu_18': 'Exist', 'Conv_19': 'Exist', 'Sigmoid_20': 'Exist', 'Mul_22': 'Exist', 'Conv_23': 'Exist', 'LeakyRelu_24': 'Exist', 'Conv_25': 'Exist', 'Conv_26': 'Exist', 'LeakyRelu_27': 'Exist', 'Conv_28': 'Exist', 'Add_29': 'Exist', 'Conv_30': 'Exist', 'Conv_31': 'Exist', 'LeakyRelu_32': 'Exist', 'Conv_33': 'Exist', 'Conv_34': 'Exist', 'LeakyRelu_35': 'Exist', 'Conv_36': 'Exist', 'Concat_37': 'Exist', 'LeakyRelu_38': 'Exist', 'Conv_39': 'Exist', + 'Conv_40': 'Exist', 'LeakyRelu_41': 'Exist', 'Conv_42': 'Exist', 'LeakyRelu_43': 'Exist', 'Conv_44': 'Exist', 'Conv_45': 'Exist', 'LeakyRelu_46': 'Exist', 'Concat_47': 'Exist', 'Reshape_49': 'Exist', 'out_onnx::Transpose_368': 'Exist', 'custom_added_Reshape0': 'Exist', 'out_custom_output_2': 'Exist'}, 'node_renamed_io': {}, 'node_changed_attr': {}, 'added_node_info': {'custom_added_Reshape0': {'properties': {'domain': 'ai.onnx', 'op_type': 'Reshape', 'name': 'custom_added_Reshape0'}, 'attributes': {}, 'inputs': {'data': ['onnx::Transpose_368'], 'shape': ['custom_input_1']}, 'outputs': {'reshaped': ['custom_output_2']}}}, 'added_outputs': {'0': 'custom_output_2'}, 'rebatch_info': {}, 'changed_initializer': {'custom_input_1': ['int64', '[1, 2, 32, 24, 6]']}} + onnx_modifier.modify(modify_info) + onnx_modifier.check_and_save_model() + onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['custom_output_2']) + print(onnx_modifier.initializer_name2module.keys()) + for initializer in onnx_modifier.initializer: + print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}") + + test_modify_new_initializer() \ No newline at end of file diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 07bbe30..edc77a8 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -1003,20 +1003,74 @@ sidebar.ArgumentView = class { this._element.appendChild(location); } - if (initializer || this._argument.is_custom_added) { - const editInitializer = this._host.document.createElement('div'); - editInitializer.className = 'sidebar-view-item-value-line-border'; - editInitializer.innerHTML = 'If this is an initializer, you can input new value for it here:'; - this._element.appendChild(editInitializer); + if (initializer) { + const editInitializerVal = this._host.document.createElement('div'); + editInitializerVal.className = 'sidebar-view-item-value-line-border'; + editInitializerVal.innerHTML = 'This is an initializer, you can input a new value for it here:'; + this._element.appendChild(editInitializerVal); - var inputInitializer = document.createElement("INPUT"); - inputInitializer.setAttribute("type", "text"); - inputInitializer.setAttribute("size", "42"); - inputInitializer.addEventListener('input', (e) => { + var inputInitializerVal = document.createElement("INPUT"); + inputInitializerVal.setAttribute("type", "text"); + inputInitializerVal.setAttribute("size", "42"); + // reload the last value + var orig_arg_name = this._host._view._graph.getOriginalName(this._param_type, this._modelNodeName, this._param_index, this._arg_index) + if (this._host._view._graph._initializerEditInfo.get(orig_arg_name)) { + // [type, value] + inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(orig_arg_name)[1]); + } + + inputInitializerVal.addEventListener('input', (e) => { // console.log(e.target.value) this._host._view._graph.changeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, this._argument.type._dataType, e.target.value); }); - this._element.appendChild(inputInitializer); + this._element.appendChild(inputInitializerVal); + } + + if (this._argument.is_custom_added) { + var new_init_val = "", new_init_type = ""; + // ====== input value ======> + const editInitializerVal = this._host.document.createElement('div'); + editInitializerVal.className = 'sidebar-view-item-value-line-border'; + editInitializerVal.innerHTML = 'If this is an initializer, you can input new value for it here:'; + this._element.appendChild(editInitializerVal); + + var inputInitializerVal = document.createElement("INPUT"); + inputInitializerVal.setAttribute("type", "text"); + inputInitializerVal.setAttribute("size", "42"); + // this._element.appendChild(inputInitializerVal); + + inputInitializerVal.addEventListener('input', (e) => { + // console.log(e.target.value) + new_init_val = e.target.value; + this._host._view._graph.changeAddedNodeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, new_init_type, new_init_val); + }); + this._element.appendChild(inputInitializerVal); + // <====== input value ====== + + // ====== input type ======> + const editInitializerType = this._host.document.createElement('div'); + editInitializerType.className = 'sidebar-view-item-value-line-border'; + editInitializerType.innerHTML = 'and input its type for it here (see properties->type->? for more info):'; + this._element.appendChild(editInitializerType); + + var inputInitializerType = document.createElement("INPUT"); + inputInitializerType.setAttribute("type", "text"); + inputInitializerType.setAttribute("size", "42"); + + var arg_name = this._host._view._graph._addedNode.get(this._modelNodeName).inputs.get(this._parameterName)[this._arg_index] + if (this._host._view._graph._initializerEditInfo.get(arg_name)) { + // [type, value] + inputInitializerType.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[0]); + inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[1]); + } + + inputInitializerType.addEventListener('input', (e) => { + // console.log(e.target.value) + new_init_type = e.target.value; + this._host._view._graph.changeAddedNodeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, new_init_type, new_init_val); + }); + this._element.appendChild(inputInitializerType); + // <====== input type ====== } if (initializer) { diff --git a/static/view.js b/static/view.js index 9d356bc..665f0b6 100644 --- a/static/view.js +++ b/static/view.js @@ -585,6 +585,7 @@ view.View = class { viewGraph._addedNode = this.lastViewGraph._addedNode; viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey; viewGraph._addedOutputs = this.lastViewGraph._addedOutputs; + viewGraph._initializerEditInfo = this.lastViewGraph._initializerEditInfo; // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) @@ -1033,12 +1034,6 @@ view.View = class { } } } - - reloadLastLocation() { - const container = this._getElementById('graph'); - - - } }; view.Graph = class extends grapher.Graph { @@ -1432,29 +1427,23 @@ view.Graph = class extends grapher.Graph { } changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { - // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { - // if (this._addedNode.has(modelNodeName)) { // for custom added node - // if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { - // this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue - // } - - // if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { - // this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue - // } - // // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph - // } - // // console.log(this._addedNode) - - // else { // for the nodes in the original model - - var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) - this._initializerEditInfo.set(orig_arg_name, [type, targetValue]); - // console.log(this._renameMap) - // } - console.log(this._initializerEditInfo) - - // this.view._updateGraph() + if (this._addedNode.has(modelNodeName)) { // for custom added node } + // console.log(this._addedNode) + + else { // for the nodes in the original model + var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) + this._initializerEditInfo.set(orig_arg_name, [type, targetValue]); + } + + this.view._updateGraph() + } + + changeAddedNodeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { + var arg_name = this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] + this._initializerEditInfo.set(arg_name, [type, targetValue]); + this.view._updateGraph() + } build(document, origin) { diff --git a/utils/parse_tools.py b/utils/parse_tools.py index ead0e2d..5840f72 100644 --- a/utils/parse_tools.py +++ b/utils/parse_tools.py @@ -1,4 +1,6 @@ import numpy as np +from typing import cast +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE def parse_value(value_str, value_type): if value_type.startswith('int'): @@ -9,6 +11,7 @@ def parse_value(value_str, value_type): raise RuntimeError("type {} is not considered in current version. \ You can kindly report an issue for this problem. Thanks!".format(value_type)) +# parse numpy values from string def parse_tensor(tensor_str, tensor_type): def extract_val(): num_str = "" @@ -22,7 +25,7 @@ def parse_tensor(tensor_str, tensor_type): tensor_str = tensor_str.replace(" ", "") stk = [] - for i, c in enumerate(tensor_str): # '[' ',' ']' '.' '-' or value + for c in tensor_str: # '[' ',' ']' '.' '-' or value if c == ",": ext_val = extract_val() if ext_val is not None: stk.append(ext_val) @@ -56,6 +59,11 @@ def parse_tensor(tensor_str, tensor_type): raise RuntimeError("type {} is not considered in current version. \ You can kindly report an issue for this problem. Thanks!".format(tensor_type)) +# map np datatype to onnx datatype +# https://github.com/onnx/onnx/blob/8669fad0247799f4d8683550eec749974b4f5338/onnx/helper.py#L1177 +def np2onnxdtype(np_dtype): + return cast(int, NP_TYPE_TO_TENSOR_TYPE[np_dtype]) + if __name__ == "__main__": # tensor_str = "1" # tensor_str = "[1, 2, 3]"