diff --git a/app.py b/app.py
index 51d3d78..8410518 100644
--- a/app.py
+++ b/app.py
@@ -22,6 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
+ print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()
diff --git a/docs/rebatch.gif b/docs/rebatch.gif
new file mode 100644
index 0000000..eb5a2bd
Binary files /dev/null and b/docs/rebatch.gif differ
diff --git a/onnx_modifier.py b/onnx_modifier.py
index 5bb4029..bca2616 100644
--- a/onnx_modifier.py
+++ b/onnx_modifier.py
@@ -5,6 +5,7 @@
import os
import copy
+import struct
import numpy as np
import onnx
from utils import make_new_node, make_attr_changed_node
@@ -58,7 +59,39 @@ class onnxModifier:
self.initializer_name2module = dict()
for initializer in self.initializer:
self.initializer_name2module[initializer.name] = initializer
-
+
+ def change_batch_size(self, rebatch_info):
+ # https://github.com/onnx/onnx/issues/2182#issuecomment-881752539
+ rebatch_type = rebatch_info['type']
+ rebatch_value = rebatch_info['value']
+ if type == 'fixed':
+ rebatch_value = int(rebatch_value)
+ # print(rebatch_type, rebatch_value)
+
+ # Change batch size in input, output and value_info
+ for tensor in list(self.graph.input) + list(self.graph.value_info) + list(self.graph.output):
+ tensor.type.tensor_type.shape.dim[0].dim_param = rebatch_value
+
+ # handle reshapes
+ for node in self.graph.node:
+ if node.op_type != 'Reshape':
+ continue
+ for init in self.graph.initializer:
+ # node.input[1] is expected to be a reshape
+ if init.name != node.input[1]:
+ continue
+
+ v = rebatch_value if rebatch_type == 'fixed' else -1
+ # Shape is stored as a list of ints
+ if len(init.int64_data) > 0:
+ # This overwrites bias nodes' reshape shape but should be fine
+ init.int64_data[0] = v
+ # Shape is stored as bytes
+ elif len(init.raw_data) > 0:
+ shape = bytearray(init.raw_data)
+ struct.pack_into('q', shape, 0, v)
+ init.raw_data = bytes(shape)
+
def remove_node_by_name(self, node_name):
# remove node in graph
self.graph.node.remove(self.node_name2module[node_name])
@@ -167,6 +200,7 @@ class onnxModifier:
# print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
# print(modify_info['added_outputs'])
+ self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
@@ -185,10 +219,9 @@ class onnxModifier:
# onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
- def inference(self, x=None, output_names=None):
+ def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
import onnxruntime as rt
if not x:
- input_shape = [1, 3, 224, 224]
x = np.random.randn(*input_shape).astype(np.float32)
if not output_names:
output_name = self.graph.node[-1].output[0]
@@ -207,13 +240,7 @@ class onnxModifier:
print(out.shape)
if __name__ == "__main__":
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
- model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
- # model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
+ model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():
@@ -261,7 +288,6 @@ if __name__ == "__main__":
for initializer in onnx_modifier.initializer:
print(initializer.name)
-
# print('\nleft nodes:')
# for node in onnx_modifier.graph.node:
# print(node.name)
@@ -306,4 +332,22 @@ if __name__ == "__main__":
onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
# print(onnx_modifier.graph.output)
onnx_modifier.check_and_save_model()
- # test_add_output()
\ No newline at end of file
+ # test_add_output()
+
+ def test_change_batch_size():
+ onnx_modifier.inference(input_shape=(1, 3, 640, 640))
+ print("batch size 1 passed")
+
+ onnx_modifier.reload()
+ onnx_modifier.change_batch_size({'type': 'fixed', 'value': '2'})
+ onnx_modifier.inference(input_shape=(2, 3, 640, 640))
+ print("batch size 2 passed")
+
+ onnx_modifier.reload()
+ onnx_modifier.change_batch_size({'type': 'dynamic', 'value': 'dynamic'})
+ onnx_modifier.inference(input_shape=(6, 3, 640, 640))
+ print("batch size dynamic passed")
+
+ onnx_modifier.check_and_save_model()
+ # test_change_batch_size()
+
\ No newline at end of file
diff --git a/readme.md b/readme.md
index 3cb121e..8e9ecca 100644
--- a/readme.md
+++ b/readme.md
@@ -23,6 +23,7 @@ Currently, the following editing operations are supported:
- [x] Add new model outputs
- [x] Edit the attribute of nodes
- [x] Add new nodes (experimental)
+- [x] Change batch size
Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md).
@@ -189,6 +190,17 @@ The following are some notes for this feature:
7. This feature is experimentally supported now and may be not very robust. So any issues are warmly welcomed if some unexpected results are encountered.
+## Change batch size
+`onnx-modifier` supports editing batch size now. Both `Dynamic batch size` and `Fixed batch size` modes are supported.
+- `Dynamic batch size`: Click the `Dynamic batch size` button, then we get a model which supports dynamic batch size inferece;
+- `Fixed batch size`: Input the fixed batch size we want, then we are done;
+
+
+
+Note the differences between `fixed batch size inference` and `dynamic batch size inference`, as [this blog](https://nietras.com/2021/05/24/set-dynamic-batch-size-using-onnx-sharp/) illustrates:
+> - When running a model with only fixed dimensions, the ONNX Runtime will prepare and optimize the graph for execution when constructing the Inference Session.
+> - when the model has dynamic dimensions like batch size, the ONNX Runtime may instead cache optimized graphs for specific batch sizes when inputs are first encountered for that batch size.
+
# Sample models
For quick testing, some typical sample models are provided as following. Most of them are from [onnx model zoo](https://github.com/onnx/models)
diff --git a/readme_zh-CN.md b/readme_zh-CN.md
index 5f66f60..be18f97 100644
--- a/readme_zh-CN.md
+++ b/readme_zh-CN.md
@@ -24,6 +24,7 @@
- [x] 增加模型输出节点
- [x] 编辑节点属性值
- [x] 增加新节点(beta)
+- [x] 修改模型batch size
`onnx-modifier`基于流行的模型可视化工具 [Netron](https://github.com/lutzroeder/netron) 和轻量级Web应用框架 [flask](https://github.com/pallets/flask) 开发。希望它能给社区带来一些贡献~
@@ -155,7 +156,12 @@
6. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`),限制最多显示8个。如果一个节点实际输入/输出小于8个,则填写对应数目的输入输出即可,多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。
7. 这个功能还处在开发中,可能会不够鲁棒。所以如果大家在实际使用时碰到问题,非常欢迎提issue!
+## 修改模型batch size
+动态batch size和固定batch size均已支持。
+- 动态batch size:点击`Dynamic batch size`即可;
+- 动态bacth size:在`Fixed batch size`后方输入框内填入预期的batch size值;
+
`onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~
diff --git a/static/index.js b/static/index.js
index bf7e8ac..f5e3bb9 100644
--- a/static/index.js
+++ b/static/index.js
@@ -231,7 +231,8 @@ host.BrowserHost = class {
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
- this._view._graph._renameMap, this._view._graph._modelNodeName2State))
+ this._view._graph._renameMap, this._view._graph._modelNodeName2State)),
+ 'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo)
})
}).then(function (response) {
return response.text();
diff --git a/static/view-grapher.js b/static/view-grapher.js
index 169a414..43a19ae 100644
--- a/static/view-grapher.js
+++ b/static/view-grapher.js
@@ -25,6 +25,8 @@ grapher.Graph = class {
this._addedNode = new Map();
this._addedOutputs = [];
+
+ this._reBatchInfo = new Map();
}
get options() {
diff --git a/static/view-sidebar.js b/static/view-sidebar.js
index 71f7ffc..b2d6f89 100644
--- a/static/view-sidebar.js
+++ b/static/view-sidebar.js
@@ -1146,6 +1146,10 @@ sidebar.ModelSidebar = class {
const separator = this._host.document.createElement('div');
separator.className = 'sidebar-view-separator';
this._elements.push(separator);
+
+ this._addHeader('Batch size changing helper');
+ this._addRebatcher();
+
}
render() {
@@ -1164,6 +1168,38 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render());
}
+ _addRebatcher() {
+ this._addButton("Dynamic batch size");
+
+ var fixed_batch_size_title = this._host.document.createElement('span');
+ fixed_batch_size_title.innerHTML = " or Fixed batch size ";
+ fixed_batch_size_title.setAttribute('style','font-size:14px');
+ this._elements.push(fixed_batch_size_title);
+
+ var fixed_batch_size_value = this._host.document.createElement("INPUT");
+ fixed_batch_size_value.setAttribute("type", "text");
+ fixed_batch_size_value.setAttribute("size", "5");
+ fixed_batch_size_value.setAttribute("value", 1);
+ fixed_batch_size_value.addEventListener('input', (e) => {
+ this._host._view._graph.changeBatchSize('fixed', e.target.value);
+ });
+
+ this._elements.push(fixed_batch_size_value);
+ }
+
+ _addButton(title) {
+ const buttonElement = this._host.document.createElement('button');
+ buttonElement.className = 'sidebar-view-button';
+ buttonElement.innerText = title;
+ this._elements.push(buttonElement);
+
+ if (title === 'Dynamic batch size') {
+ buttonElement.addEventListener('click', () => {
+ this._host._view._graph.changeBatchSize("dynamic")
+ });
+ }
+ }
+
addArgument(name, argument, index, arg_type) {
// const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name);
diff --git a/static/view.js b/static/view.js
index 1cdc2d2..ef1e5aa 100644
--- a/static/view.js
+++ b/static/view.js
@@ -1254,6 +1254,17 @@ view.Graph = class extends grapher.Graph {
this.view._updateGraph();
}
+ changeBatchSize(type, value) {
+ if (type === "fixed") {
+ this._reBatchInfo.set("type", "fixed");
+ this._reBatchInfo.set("value", value);
+ }
+ else { // dynamic
+ this._reBatchInfo.set("type", "dynamic");
+ this._reBatchInfo.set("value", "dynamic");
+ }
+ }
+
resetGraph() {
// reset node states
for (const nodeId of this.nodes.keys()) {