diff --git a/app.py b/app.py index 51d3d78..8410518 100644 --- a/app.py +++ b/app.py @@ -22,6 +22,7 @@ def open_model(): @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() + print(modify_info) onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.modify(modify_info) onnx_modifier.check_and_save_model() diff --git a/docs/rebatch.gif b/docs/rebatch.gif new file mode 100644 index 0000000..eb5a2bd Binary files /dev/null and b/docs/rebatch.gif differ diff --git a/onnx_modifier.py b/onnx_modifier.py index 5bb4029..bca2616 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -5,6 +5,7 @@ import os import copy +import struct import numpy as np import onnx from utils import make_new_node, make_attr_changed_node @@ -58,7 +59,39 @@ class onnxModifier: self.initializer_name2module = dict() for initializer in self.initializer: self.initializer_name2module[initializer.name] = initializer - + + def change_batch_size(self, rebatch_info): + # https://github.com/onnx/onnx/issues/2182#issuecomment-881752539 + rebatch_type = rebatch_info['type'] + rebatch_value = rebatch_info['value'] + if type == 'fixed': + rebatch_value = int(rebatch_value) + # print(rebatch_type, rebatch_value) + + # Change batch size in input, output and value_info + for tensor in list(self.graph.input) + list(self.graph.value_info) + list(self.graph.output): + tensor.type.tensor_type.shape.dim[0].dim_param = rebatch_value + + # handle reshapes + for node in self.graph.node: + if node.op_type != 'Reshape': + continue + for init in self.graph.initializer: + # node.input[1] is expected to be a reshape + if init.name != node.input[1]: + continue + + v = rebatch_value if rebatch_type == 'fixed' else -1 + # Shape is stored as a list of ints + if len(init.int64_data) > 0: + # This overwrites bias nodes' reshape shape but should be fine + init.int64_data[0] = v + # Shape is stored as bytes + elif len(init.raw_data) > 0: + shape = bytearray(init.raw_data) + struct.pack_into('q', shape, 0, v) + init.raw_data = bytes(shape) + def remove_node_by_name(self, node_name): # remove node in graph self.graph.node.remove(self.node_name2module[node_name]) @@ -167,6 +200,7 @@ class onnxModifier: # print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) # print(modify_info['added_outputs']) + self.change_batch_size(modify_info['rebatch_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_attr(modify_info['node_changed_attr']) @@ -185,10 +219,9 @@ class onnxModifier: # onnx.checker.check_model(self.model_proto) onnx.save(self.model_proto, save_path) - def inference(self, x=None, output_names=None): + def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None): import onnxruntime as rt if not x: - input_shape = [1, 3, 224, 224] x = np.random.randn(*input_shape).astype(np.float32) if not output_names: output_name = self.graph.node[-1].output[0] @@ -207,13 +240,7 @@ class onnxModifier: print(out.shape) if __name__ == "__main__": - # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" - # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" - # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" - # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx" - # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why? - # model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): @@ -261,7 +288,6 @@ if __name__ == "__main__": for initializer in onnx_modifier.initializer: print(initializer.name) - # print('\nleft nodes:') # for node in onnx_modifier.graph.node: # print(node.name) @@ -306,4 +332,22 @@ if __name__ == "__main__": onnx_modifier.add_outputs(['fire2/squeeze1x1_1']) # print(onnx_modifier.graph.output) onnx_modifier.check_and_save_model() - # test_add_output() \ No newline at end of file + # test_add_output() + + def test_change_batch_size(): + onnx_modifier.inference(input_shape=(1, 3, 640, 640)) + print("batch size 1 passed") + + onnx_modifier.reload() + onnx_modifier.change_batch_size({'type': 'fixed', 'value': '2'}) + onnx_modifier.inference(input_shape=(2, 3, 640, 640)) + print("batch size 2 passed") + + onnx_modifier.reload() + onnx_modifier.change_batch_size({'type': 'dynamic', 'value': 'dynamic'}) + onnx_modifier.inference(input_shape=(6, 3, 640, 640)) + print("batch size dynamic passed") + + onnx_modifier.check_and_save_model() + # test_change_batch_size() + \ No newline at end of file diff --git a/readme.md b/readme.md index 3cb121e..8e9ecca 100644 --- a/readme.md +++ b/readme.md @@ -23,6 +23,7 @@ Currently, the following editing operations are supported: - [x] Add new model outputs - [x] Edit the attribute of nodes - [x] Add new nodes (experimental) +- [x] Change batch size Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md). @@ -189,6 +190,17 @@ The following are some notes for this feature: 7. This feature is experimentally supported now and may be not very robust. So any issues are warmly welcomed if some unexpected results are encountered. +## Change batch size +`onnx-modifier` supports editing batch size now. Both `Dynamic batch size` and `Fixed batch size` modes are supported. +- `Dynamic batch size`: Click the `Dynamic batch size` button, then we get a model which supports dynamic batch size inferece; +- `Fixed batch size`: Input the fixed batch size we want, then we are done; + + + +Note the differences between `fixed batch size inference` and `dynamic batch size inference`, as [this blog](https://nietras.com/2021/05/24/set-dynamic-batch-size-using-onnx-sharp/) illustrates: +> - When running a model with only fixed dimensions, the ONNX Runtime will prepare and optimize the graph for execution when constructing the Inference Session. +> - when the model has dynamic dimensions like batch size, the ONNX Runtime may instead cache optimized graphs for specific batch sizes when inputs are first encountered for that batch size. + # Sample models For quick testing, some typical sample models are provided as following. Most of them are from [onnx model zoo](https://github.com/onnx/models) diff --git a/readme_zh-CN.md b/readme_zh-CN.md index 5f66f60..be18f97 100644 --- a/readme_zh-CN.md +++ b/readme_zh-CN.md @@ -24,6 +24,7 @@ - [x] 增加模型输出节点 - [x] 编辑节点属性值 - [x] 增加新节点(beta) +- [x] 修改模型batch size `onnx-modifier`基于流行的模型可视化工具 [Netron](https://github.com/lutzroeder/netron) 和轻量级Web应用框架 [flask](https://github.com/pallets/flask) 开发。希望它能给社区带来一些贡献~ @@ -155,7 +156,12 @@ 6. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`),限制最多显示8个。如果一个节点实际输入/输出小于8个,则填写对应数目的输入输出即可,多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。 7. 这个功能还处在开发中,可能会不够鲁棒。所以如果大家在实际使用时碰到问题,非常欢迎提issue! +## 修改模型batch size +动态batch size和固定batch size均已支持。 +- 动态batch size:点击`Dynamic batch size`即可; +- 动态bacth size:在`Fixed batch size`后方输入框内填入预期的batch size值; + `onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~ diff --git a/static/index.js b/static/index.js index bf7e8ac..f5e3bb9 100644 --- a/static/index.js +++ b/static/index.js @@ -231,7 +231,8 @@ host.BrowserHost = class { 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)), 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, - this._view._graph._renameMap, this._view._graph._modelNodeName2State)) + this._view._graph._renameMap, this._view._graph._modelNodeName2State)), + 'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo) }) }).then(function (response) { return response.text(); diff --git a/static/view-grapher.js b/static/view-grapher.js index 169a414..43a19ae 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -25,6 +25,8 @@ grapher.Graph = class { this._addedNode = new Map(); this._addedOutputs = []; + + this._reBatchInfo = new Map(); } get options() { diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 71f7ffc..b2d6f89 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -1146,6 +1146,10 @@ sidebar.ModelSidebar = class { const separator = this._host.document.createElement('div'); separator.className = 'sidebar-view-separator'; this._elements.push(separator); + + this._addHeader('Batch size changing helper'); + this._addRebatcher(); + } render() { @@ -1164,6 +1168,38 @@ sidebar.ModelSidebar = class { this._elements.push(item.render()); } + _addRebatcher() { + this._addButton("Dynamic batch size"); + + var fixed_batch_size_title = this._host.document.createElement('span'); + fixed_batch_size_title.innerHTML = "    or   Fixed batch size   "; + fixed_batch_size_title.setAttribute('style','font-size:14px'); + this._elements.push(fixed_batch_size_title); + + var fixed_batch_size_value = this._host.document.createElement("INPUT"); + fixed_batch_size_value.setAttribute("type", "text"); + fixed_batch_size_value.setAttribute("size", "5"); + fixed_batch_size_value.setAttribute("value", 1); + fixed_batch_size_value.addEventListener('input', (e) => { + this._host._view._graph.changeBatchSize('fixed', e.target.value); + }); + + this._elements.push(fixed_batch_size_value); + } + + _addButton(title) { + const buttonElement = this._host.document.createElement('button'); + buttonElement.className = 'sidebar-view-button'; + buttonElement.innerText = title; + this._elements.push(buttonElement); + + if (title === 'Dynamic batch size') { + buttonElement.addEventListener('click', () => { + this._host._view._graph.changeBatchSize("dynamic") + }); + } + } + addArgument(name, argument, index, arg_type) { // const view = new sidebar.ParameterView(this._host, argument); const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name); diff --git a/static/view.js b/static/view.js index 1cdc2d2..ef1e5aa 100644 --- a/static/view.js +++ b/static/view.js @@ -1254,6 +1254,17 @@ view.Graph = class extends grapher.Graph { this.view._updateGraph(); } + changeBatchSize(type, value) { + if (type === "fixed") { + this._reBatchInfo.set("type", "fixed"); + this._reBatchInfo.set("value", value); + } + else { // dynamic + this._reBatchInfo.set("type", "dynamic"); + this._reBatchInfo.set("value", "dynamic"); + } + } + resetGraph() { // reset node states for (const nodeId of this.nodes.keys()) {