diff --git a/docs/onnx_modifier_todo.md b/docs/onnx_modifier_todo.md index 2f79c90..672e5dc 100644 --- a/docs/onnx_modifier_todo.md +++ b/docs/onnx_modifier_todo.md @@ -8,6 +8,10 @@ add node (NMS): https://github.com/onnx/onnx/issues/2216 add node (add preprocess nodes): https://zhuanlan.zhihu.com/p/394395167 +combine models: + - https://stackoverflow.com/questions/66178085/can-i-combine-two-onnx-graphs-together-passing-the-output-from-one-as-input-to + - https://www.zhihu.com/people/kai-xin-zui-zhong-yao-76/posts + # modify attribute of nodes topk: https://github.com/onnx/onnx/issues/2921 @@ -26,11 +30,14 @@ http://yyixx.com/docs/algo/onnx/ # 待做的 -bug(fixed): 不可连续添加某一种类型的节点(无反应) +boost: 支持添加更复杂的节点 boost: 直接使用侧边栏inputs/outputs属性框完成重命名,并提供reset功能 boost: 支持处理属性的修改 -boost: 支持添加更复杂的节点 -question: 在add()函数里,为什么对conv的inputs进行遍历,只能得到X,而得不到W和B? + - 暂时只支持新增节点,而不支持已有节点开放(防止出现不必要的错误) + +(fixed)bug: 不可连续添加某一种类型的节点(无反应) +(solved)question: 在add()函数里,为什么对conv的inputs进行遍历,只能得到X,而得不到W和B? + - 因为遍历的条件里,有判断是否有initializer的条件 # 其他 diff --git a/onnx_modifier.py b/onnx_modifier.py index 4fa6856..5eb6d03 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -5,7 +5,10 @@ import os import copy +import numpy as np import onnx +import onnxruntime as rt +from utils import make_node class onnxModifier: def __init__(self, model_name, model_proto): @@ -63,6 +66,9 @@ class onnxModifier: def remove_node_by_node_states(self, node_states): # remove node from graph for node_name, node_state in node_states.items(): + if not (node_name in self.node_name2module): + # for custom added node here + continue if node_state == 'Deleted': if node_name in self.graph_output_names: # print('removing output {} ...'.format(node_name)) @@ -98,31 +104,15 @@ class onnxModifier: def add_node(self, nodes_info): for node_info in nodes_info.values(): - name = node_info['properties']['name'] - op_type = node_info['properties']['op_type'] - attributes = node_info['attributes'] - - inputs = [] - for key in node_info['inputs'].keys(): - inputs += node_info['inputs'][key] - outputs = [] - for key in node_info['outputs'].keys(): - outputs += node_info['outputs'][key] - - node = onnx.helper.make_node( - op_type=op_type, - inputs=inputs, - outputs=outputs, - name=name, - **attributes - ) + node = make_node(node_info) # print(node) self.graph.node.append(node) def modify(self, modify_info): - print(modify_info['node_renamed_io']) + # print(modify_info['node_renamed_io']) + print(modify_info['added_node_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.add_node(modify_info['added_node_info']) @@ -135,11 +125,29 @@ class onnxModifier: onnx.checker.check_model(self.model_proto) onnx.save(self.model_proto, save_path) - def inference(self): - # model_proto_bytes = onnx._serialize(model_proto_from_stream) - # inference_session = rt.InferenceSession(model_proto_bytes) - pass + def inference(self, x=None, output_names=None): + if not x: + input_shape = [1, 3, 224, 224] + x = np.random.randn(*input_shape).astype(np.float32) + if not output_names: + output_name = self.graph.node[-1].output[0] + # output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[]) + output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[]) + self.graph.output.append(output_value_info) + + model_proto_bytes = onnx._serialize(self.model_proto) + inference_session = rt.InferenceSession(model_proto_bytes) + input_name = inference_session.get_inputs()[0].name + output_name = inference_session.get_outputs()[0].name + # print(input_name) + # print(output_name) + + # This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506 + out = inference_session.run(None, {input_name: x})[0] + + # print(out) + if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" @@ -147,7 +155,25 @@ if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) + + def explore_basic(): + print(type(onnx_modifier.model_proto.graph.initializer)) + print(dir(onnx_modifier.model_proto.graph.initializer)) + print(len(onnx_modifier.model_proto.graph.node)) + print(len(onnx_modifier.model_proto.graph.initializer)) + + for node in onnx_modifier.model_proto.graph.node: + print(node.name) + print(node.input) + print() + + # for initializer in onnx_modifier.model_proto.graph.initializer: + # print(initializer.name) + # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) + pass + # explore_basic() + def remove_node_by_node_states(): print(len(onnx_modifier.graph.node)) print(len(onnx_modifier.graph.initializer)) @@ -184,24 +210,7 @@ if __name__ == "__main__": onnx_modifier.check_and_save_model() # remove_node_by_node_states() - def explore_basic(): - print(type(onnx_modifier.model_proto.graph.initializer)) - print(dir(onnx_modifier.model_proto.graph.initializer)) - - print(len(onnx_modifier.model_proto.graph.node)) - print(len(onnx_modifier.model_proto.graph.initializer)) - - for node in onnx_modifier.model_proto.graph.node: - print(node.name) - print(node.input) - print() - - # for initializer in onnx_modifier.model_proto.graph.initializer: - # print(initializer.name) - # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) - pass - # explore_basic() - + def test_modify_node_io_name(): node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} onnx_modifier.modify_node_io_name(node_rename_io) @@ -209,9 +218,11 @@ if __name__ == "__main__": # test_modify_node_io_name() def test_add_node(): - node_info = {'properties': {'domain': 'ai.onnx', 'op_type': 'Abs', 'name': 'custom_added_Abs0'}, 'attributes': {}, 'inputs': {'X': ['custom_input_0']}, 'outputs': {'Y': ['custom_output_1']}} + node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}} onnx_modifier.add_node(node_info) + + onnx_modifier.inference() onnx_modifier.check_and_save_model() test_add_node() diff --git a/static/index.js b/static/index.js index 79c52a3..c2f8a85 100644 --- a/static/index.js +++ b/static/index.js @@ -244,7 +244,7 @@ host.BrowserHost = class { } else { // swal("Error happens!", "You are kindly to create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); - swal("Error happens!", "You can check the log and kindly create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); + swal("Error happens!", "You are kindly to check the log and create an issue on https://github.com/ZhangGe6/onnx-modifier", "error"); // alert('Error happens, you can find it out or create an issue on https://github.com/ZhangGe6/onnx-modifier') } }); diff --git a/static/onnx.js b/static/onnx.js index 4d61676..85f6ffe 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -514,7 +514,7 @@ onnx.Graph = class { reset_custom_added_node() { this._custom_added_node = [] - this._custom_add_node_io_idx = 0 + // this._custom_add_node_io_idx = 0 } toString() { @@ -645,7 +645,7 @@ onnx.Graph = class { node_info.properties.get('op_type'), node_info.properties.get('domain'), node_info.properties.get('name'), - schema.description, + null, // schema.description, // omit it to save sidebar space. The node description can also be seen in the node `type` expander attributes, inputs, outputs @@ -693,9 +693,9 @@ onnx.Argument = class { // this._renamed = false; // this._new_name = null; - console.log(original_name) + // console.log(original_name) this.original_name = original_name || name - console.log(this.original_name) + // console.log(this.original_name) } diff --git a/static/view-sidebar.js b/static/view-sidebar.js index b5946e7..3312370 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -201,24 +201,26 @@ sidebar.NodeSidebar = class { this._elements.push(this._host.document.createElement('hr')); this.add_separator(this._elements, 'sidebar-view-separator') + this._addHeader('Node deleting helper'); this._addButton('Delete With Children'); this.add_span() this._addButton('Delete Single Node'); this.add_span() this._addButton('Recover Node'); - - this.add_separator(this._elements, 'sidebar-view-separator'); - this._addHeader('Rename helper'); - if (inputs && inputs.length > 0) { - for (const input of inputs) { - this.add_rename_aux_element(input.arguments); - } - } - if (outputs && outputs.length > 0) { - for (const output of outputs) { - this.add_rename_aux_element(output.arguments); - } - } + + // deprecated + // this.add_separator(this._elements, 'sidebar-view-separator'); + // this._addHeader('Rename helper'); + // if (inputs && inputs.length > 0) { + // for (const input of inputs) { + // this.add_rename_aux_element(input.arguments); + // } + // } + // if (outputs && outputs.length > 0) { + // for (const output of outputs) { + // this.add_rename_aux_element(output.arguments); + // } + // } // this.add_separator(this._elements, 'sidebar-view-separator'); // this._addHeader('Add children node'); @@ -648,6 +650,7 @@ class NodeAttributeView { this._element.appendChild(this._expander); } const value = this._attribute.value; + console.log(this._attribute.name, value, type) switch (type) { case 'graph': { const line = this._host.document.createElement('div'); @@ -687,8 +690,8 @@ class NodeAttributeView { attr_input.setAttribute("size", "42"); attr_input.setAttribute("value", content ? content : 'undefined'); attr_input.addEventListener('input', (e) => { - // console.log(e.target.value); - this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, e.target.value); + console.log(e.target.value); + this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, this.parse_value(e.target.value, type)); // console.log(this._host._view._graph._renameMap); }); @@ -758,6 +761,34 @@ class NodeAttributeView { } } } + + parse_value(value, type) { + if (value == 'undefined') { + // alert(""); + return value + } + + switch (type) { + case "int64": + return parseInt(value) + // case "" + case "int64[]": + var val = [] + for (var v of value.split(",")) { + val.push(parseInt(v)) + } + return val + + case "float32": + return parseFloat(value) + case "float32[]": + var val = [] + for (var v of value.split(",")) { + val.push(parseFloat(v)) + } + return val + } + } } sidebar.ParameterView = class { diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..2cdf9bf --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .make_nodes import * \ No newline at end of file diff --git a/utils/make_nodes.py b/utils/make_nodes.py new file mode 100644 index 0000000..bc69836 --- /dev/null +++ b/utils/make_nodes.py @@ -0,0 +1,28 @@ +import onnx + +def make_node(node_info): + name = node_info['properties']['name'] + op_type = node_info['properties']['op_type'] + attributes = node_info['attributes'] + # attributes = {k: v for k, v in node_info['attributes'].items() if not v == 'undefined'} + # print(attributes) + + inputs = [] + for key in node_info['inputs'].keys(): + inputs += node_info['inputs'][key] + outputs = [] + for key in node_info['outputs'].keys(): + outputs += node_info['outputs'][key] + + # https://github.com/onnx/onnx/blob/main/onnx/helper.py#L82 + node = onnx.helper.make_node( + op_type=op_type, + inputs=inputs, + outputs=outputs, + name=name, + **attributes + ) + + # print(node) + + return node \ No newline at end of file