1123
ZhangGe6 3 years ago
parent 49de3e7f21
commit 7e8e7e9ffa

@ -22,6 +22,7 @@ def open_model():
@app.route('/download', methods=['POST']) @app.route('/download', methods=['POST'])
def modify_and_download_model(): def modify_and_download_model():
modify_info = request.get_json() modify_info = request.get_json()
print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.reload() # allow downloading for multiple times
onnx_modifier.modify(modify_info) onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()

Binary file not shown.

After

Width:  |  Height:  |  Size: 865 KiB

@ -5,6 +5,7 @@
import os import os
import copy import copy
import struct
import numpy as np import numpy as np
import onnx import onnx
from utils import make_new_node, make_attr_changed_node from utils import make_new_node, make_attr_changed_node
@ -59,6 +60,38 @@ class onnxModifier:
for initializer in self.initializer: for initializer in self.initializer:
self.initializer_name2module[initializer.name] = initializer self.initializer_name2module[initializer.name] = initializer
def change_batch_size(self, rebatch_info):
# https://github.com/onnx/onnx/issues/2182#issuecomment-881752539
rebatch_type = rebatch_info['type']
rebatch_value = rebatch_info['value']
if type == 'fixed':
rebatch_value = int(rebatch_value)
# print(rebatch_type, rebatch_value)
# Change batch size in input, output and value_info
for tensor in list(self.graph.input) + list(self.graph.value_info) + list(self.graph.output):
tensor.type.tensor_type.shape.dim[0].dim_param = rebatch_value
# handle reshapes
for node in self.graph.node:
if node.op_type != 'Reshape':
continue
for init in self.graph.initializer:
# node.input[1] is expected to be a reshape
if init.name != node.input[1]:
continue
v = rebatch_value if rebatch_type == 'fixed' else -1
# Shape is stored as a list of ints
if len(init.int64_data) > 0:
# This overwrites bias nodes' reshape shape but should be fine
init.int64_data[0] = v
# Shape is stored as bytes
elif len(init.raw_data) > 0:
shape = bytearray(init.raw_data)
struct.pack_into('q', shape, 0, v)
init.raw_data = bytes(shape)
def remove_node_by_name(self, node_name): def remove_node_by_name(self, node_name):
# remove node in graph # remove node in graph
self.graph.node.remove(self.node_name2module[node_name]) self.graph.node.remove(self.node_name2module[node_name])
@ -167,6 +200,7 @@ class onnxModifier:
# print(modify_info['node_changed_attr']) # print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info']) # print(modify_info['added_node_info'])
# print(modify_info['added_outputs']) # print(modify_info['added_outputs'])
self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states']) self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr']) self.modify_node_attr(modify_info['node_changed_attr'])
@ -185,10 +219,9 @@ class onnxModifier:
# onnx.checker.check_model(self.model_proto) # onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path) onnx.save(self.model_proto, save_path)
def inference(self, x=None, output_names=None): def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
import onnxruntime as rt import onnxruntime as rt
if not x: if not x:
input_shape = [1, 3, 224, 224]
x = np.random.randn(*input_shape).astype(np.float32) x = np.random.randn(*input_shape).astype(np.float32)
if not output_names: if not output_names:
output_name = self.graph.node[-1].output[0] output_name = self.graph.node[-1].output[0]
@ -207,13 +240,7 @@ class onnxModifier:
print(out.shape) print(out.shape)
if __name__ == "__main__": if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic(): def explore_basic():
@ -261,7 +288,6 @@ if __name__ == "__main__":
for initializer in onnx_modifier.initializer: for initializer in onnx_modifier.initializer:
print(initializer.name) print(initializer.name)
# print('\nleft nodes:') # print('\nleft nodes:')
# for node in onnx_modifier.graph.node: # for node in onnx_modifier.graph.node:
# print(node.name) # print(node.name)
@ -307,3 +333,21 @@ if __name__ == "__main__":
# print(onnx_modifier.graph.output) # print(onnx_modifier.graph.output)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
# test_add_output() # test_add_output()
def test_change_batch_size():
onnx_modifier.inference(input_shape=(1, 3, 640, 640))
print("batch size 1 passed")
onnx_modifier.reload()
onnx_modifier.change_batch_size({'type': 'fixed', 'value': '2'})
onnx_modifier.inference(input_shape=(2, 3, 640, 640))
print("batch size 2 passed")
onnx_modifier.reload()
onnx_modifier.change_batch_size({'type': 'dynamic', 'value': 'dynamic'})
onnx_modifier.inference(input_shape=(6, 3, 640, 640))
print("batch size dynamic passed")
onnx_modifier.check_and_save_model()
# test_change_batch_size()

@ -23,6 +23,7 @@ Currently, the following editing operations are supported:
- [x] Add new model outputs - [x] Add new model outputs
- [x] Edit the attribute of nodes - [x] Edit the attribute of nodes
- [x] Add new nodes (experimental) - [x] Add new nodes (experimental)
- [x] Change batch size
Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md). Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md).
@ -189,6 +190,17 @@ The following are some notes for this feature:
7. This feature is experimentally supported now and may be not very robust. So any issues are warmly welcomed if some unexpected results are encountered. 7. This feature is experimentally supported now and may be not very robust. So any issues are warmly welcomed if some unexpected results are encountered.
## Change batch size
`onnx-modifier` supports editing batch size now. Both `Dynamic batch size` and `Fixed batch size` modes are supported.
- `Dynamic batch size`: Click the `Dynamic batch size` button, then we get a model which supports dynamic batch size inferece;
- `Fixed batch size`: Input the fixed batch size we want, then we are done;
<img src="./docs/rebatch.gif" style="zoom:75%;" />
Note the differences between `fixed batch size inference` and `dynamic batch size inference`, as [this blog](https://nietras.com/2021/05/24/set-dynamic-batch-size-using-onnx-sharp/) illustrates:
> - When running a model with only fixed dimensions, the ONNX Runtime will prepare and optimize the graph for execution when constructing the Inference Session.
> - when the model has dynamic dimensions like batch size, the ONNX Runtime may instead cache optimized graphs for specific batch sizes when inputs are first encountered for that batch size.
# Sample models # Sample models
For quick testing, some typical sample models are provided as following. Most of them are from [onnx model zoo](https://github.com/onnx/models) For quick testing, some typical sample models are provided as following. Most of them are from [onnx model zoo](https://github.com/onnx/models)

@ -24,6 +24,7 @@
- [x] 增加模型输出节点 - [x] 增加模型输出节点
- [x] 编辑节点属性值 - [x] 编辑节点属性值
- [x] 增加新节点beta - [x] 增加新节点beta
- [x] 修改模型batch size
`onnx-modifier`基于流行的模型可视化工具 [Netron](https://github.com/lutzroeder/netron) 和轻量级Web应用框架 [flask](https://github.com/pallets/flask) 开发。希望它能给社区带来一些贡献~ `onnx-modifier`基于流行的模型可视化工具 [Netron](https://github.com/lutzroeder/netron) 和轻量级Web应用框架 [flask](https://github.com/pallets/flask) 开发。希望它能给社区带来一些贡献~
@ -155,7 +156,12 @@
6. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`限制最多显示8个。如果一个节点实际输入/输出小于8个则填写对应数目的输入输出即可多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。 6. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`限制最多显示8个。如果一个节点实际输入/输出小于8个则填写对应数目的输入输出即可多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。
7. 这个功能还处在开发中可能会不够鲁棒。所以如果大家在实际使用时碰到问题非常欢迎提issue! 7. 这个功能还处在开发中可能会不够鲁棒。所以如果大家在实际使用时碰到问题非常欢迎提issue!
## 修改模型batch size
动态batch size和固定batch size均已支持。
- 动态batch size点击`Dynamic batch size`即可;
- 动态bacth size在`Fixed batch size`后方输入框内填入预期的batch size值
<img src="./docs/rebatch.gif" style="zoom:75%;" />
`onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用提issue如果有帮助的话感谢给个:star:~ `onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用提issue如果有帮助的话感谢给个:star:~

@ -231,7 +231,8 @@ host.BrowserHost = class {
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)), 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
this._view._graph._renameMap, this._view._graph._modelNodeName2State)) this._view._graph._renameMap, this._view._graph._modelNodeName2State)),
'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo)
}) })
}).then(function (response) { }).then(function (response) {
return response.text(); return response.text();

@ -25,6 +25,8 @@ grapher.Graph = class {
this._addedNode = new Map(); this._addedNode = new Map();
this._addedOutputs = []; this._addedOutputs = [];
this._reBatchInfo = new Map();
} }
get options() { get options() {

@ -1146,6 +1146,10 @@ sidebar.ModelSidebar = class {
const separator = this._host.document.createElement('div'); const separator = this._host.document.createElement('div');
separator.className = 'sidebar-view-separator'; separator.className = 'sidebar-view-separator';
this._elements.push(separator); this._elements.push(separator);
this._addHeader('Batch size changing helper');
this._addRebatcher();
} }
render() { render() {
@ -1164,6 +1168,38 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render()); this._elements.push(item.render());
} }
_addRebatcher() {
this._addButton("Dynamic batch size");
var fixed_batch_size_title = this._host.document.createElement('span');
fixed_batch_size_title.innerHTML = "&nbsp;&nbsp;&nbsp;<strong> or </strong>&nbsp;&nbsp;Fixed batch size&nbsp;&nbsp;&nbsp;";
fixed_batch_size_title.setAttribute('style','font-size:14px');
this._elements.push(fixed_batch_size_title);
var fixed_batch_size_value = this._host.document.createElement("INPUT");
fixed_batch_size_value.setAttribute("type", "text");
fixed_batch_size_value.setAttribute("size", "5");
fixed_batch_size_value.setAttribute("value", 1);
fixed_batch_size_value.addEventListener('input', (e) => {
this._host._view._graph.changeBatchSize('fixed', e.target.value);
});
this._elements.push(fixed_batch_size_value);
}
_addButton(title) {
const buttonElement = this._host.document.createElement('button');
buttonElement.className = 'sidebar-view-button';
buttonElement.innerText = title;
this._elements.push(buttonElement);
if (title === 'Dynamic batch size') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.changeBatchSize("dynamic")
});
}
}
addArgument(name, argument, index, arg_type) { addArgument(name, argument, index, arg_type) {
// const view = new sidebar.ParameterView(this._host, argument); // const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name); const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name);

@ -1254,6 +1254,17 @@ view.Graph = class extends grapher.Graph {
this.view._updateGraph(); this.view._updateGraph();
} }
changeBatchSize(type, value) {
if (type === "fixed") {
this._reBatchInfo.set("type", "fixed");
this._reBatchInfo.set("value", value);
}
else { // dynamic
this._reBatchInfo.set("type", "dynamic");
this._reBatchInfo.set("value", "dynamic");
}
}
resetGraph() { resetGraph() {
// reset node states // reset node states
for (const nodeId of this.nodes.keys()) { for (const nodeId of this.nodes.keys()) {

Loading…
Cancel
Save