ZhangGe6 3 years ago
parent 49de3e7f21
commit 7e8e7e9ffa

@@ -22,6 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
print(modify_info)
onnx_modifier.reload() # allow downloading multiple times
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()
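For reference, the `modify_info` object received here mirrors what `host.js` posts; a rough sketch of its shape (keys taken from the front end, values illustrative only):

```python
# Illustrative sketch of the JSON body received by /download; keys mirror
# what host.js sends, the values here are made up for demonstration.
modify_info = {
    'node_states': {},        # nodes marked as deleted/kept
    'node_renamed_io': {},    # renamed inputs/outputs
    'node_changed_attr': {},  # edited node attributes
    'added_node_info': {},    # newly added nodes
    'added_outputs': {},      # extra model outputs
    'rebatch_info': {'type': 'fixed', 'value': '2'},
    # or: 'rebatch_info': {'type': 'dynamic', 'value': 'dynamic'}
}
```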

Binary file not shown.


@@ -5,6 +5,7 @@
import os
import copy
import struct
import numpy as np
import onnx
from utils import make_new_node, make_attr_changed_node
@@ -58,7 +59,39 @@ class onnxModifier:
self.initializer_name2module = dict()
for initializer in self.initializer:
self.initializer_name2module[initializer.name] = initializer
def change_batch_size(self, rebatch_info):
# https://github.com/onnx/onnx/issues/2182#issuecomment-881752539
rebatch_type = rebatch_info['type']
rebatch_value = rebatch_info['value']
if rebatch_type == 'fixed':
rebatch_value = int(rebatch_value)
# print(rebatch_type, rebatch_value)
# Change batch size in input, output and value_info.
# A fixed batch size is an int and belongs in dim_value; a dynamic batch
# size is a symbolic string and belongs in dim_param.
for tensor in list(self.graph.input) + list(self.graph.value_info) + list(self.graph.output):
dim = tensor.type.tensor_type.shape.dim[0]
if rebatch_type == 'fixed':
dim.dim_value = rebatch_value
else:
dim.dim_param = rebatch_value
# handle reshapes
for node in self.graph.node:
if node.op_type != 'Reshape':
continue
for init in self.graph.initializer:
# node.input[1] is expected to be the Reshape node's shape tensor
if init.name != node.input[1]:
continue
v = rebatch_value if rebatch_type == 'fixed' else -1
# Shape is stored as a list of ints
if len(init.int64_data) > 0:
# This overwrites bias nodes' reshape shape but should be fine
init.int64_data[0] = v
# Shape is stored as raw bytes (little-endian int64 values)
elif len(init.raw_data) > 0:
shape = bytearray(init.raw_data)
struct.pack_into('<q', shape, 0, v)
init.raw_data = bytes(shape)
def remove_node_by_name(self, node_name):
# remove node in graph
self.graph.node.remove(self.node_name2module[node_name])
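To make the `Reshape` patching above concrete, here is a minimal standalone sketch (assuming `onnx` and `numpy` are installed; all names are illustrative) of how a shape initializer keeps its int64 values in `raw_data` and how the leading (batch) entry can be overwritten in place:

```python
# Minimal sketch: patch the first dimension of a Reshape shape initializer.
import struct
import numpy as np
import onnx
from onnx import helper, numpy_helper

# A shape tensor as a Reshape node would consume it: [1, 3, 224, 224]
shape_init = helper.make_tensor(
    name='reshape_shape',
    data_type=onnx.TensorProto.INT64,
    dims=[4],
    vals=np.array([1, 3, 224, 224], dtype=np.int64).tobytes(),
    raw=True,
)

# Overwrite the leading (batch) entry with -1, as change_batch_size does
# in the dynamic case; raw_data stores little-endian int64 values.
buf = bytearray(shape_init.raw_data)
struct.pack_into('<q', buf, 0, -1)
shape_init.raw_data = bytes(buf)

print(numpy_helper.to_array(shape_init))  # [-1, 3, 224, 224]
```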
@@ -167,6 +200,7 @@ class onnxModifier:
# print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
# print(modify_info['added_outputs'])
self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
@@ -185,10 +219,9 @@ class onnxModifier:
# onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
def inference(self, x=None, output_names=None):
def inference(self, input_shape=(1, 3, 224, 224), x=None, output_names=None):
import onnxruntime as rt
if x is None:
input_shape = [1, 3, 224, 224]
x = np.random.randn(*input_shape).astype(np.float32)
if not output_names:
output_name = self.graph.node[-1].output[0]
@@ -207,13 +240,7 @@
print(out.shape)
if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():
@@ -261,7 +288,6 @@ if __name__ == "__main__":
for initializer in onnx_modifier.initializer:
print(initializer.name)
# print('\nleft nodes:')
# for node in onnx_modifier.graph.node:
# print(node.name)
@@ -306,4 +332,22 @@ if __name__ == "__main__":
onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
# print(onnx_modifier.graph.output)
onnx_modifier.check_and_save_model()
# test_add_output()
def test_change_batch_size():
onnx_modifier.inference(input_shape=(1, 3, 640, 640))
print("batch size 1 passed")
onnx_modifier.reload()
onnx_modifier.change_batch_size({'type': 'fixed', 'value': '2'})
onnx_modifier.inference(input_shape=(2, 3, 640, 640))
print("batch size 2 passed")
onnx_modifier.reload()
onnx_modifier.change_batch_size({'type': 'dynamic', 'value': 'dynamic'})
onnx_modifier.inference(input_shape=(6, 3, 640, 640))
print("batch size dynamic passed")
onnx_modifier.check_and_save_model()
# test_change_batch_size()

@@ -23,6 +23,7 @@ Currently, the following editing operations are supported:
- [x] Add new model outputs
- [x] Edit the attribute of nodes
- [x] Add new nodes (experimental)
- [x] Change batch size
Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md).
@@ -189,6 +190,17 @@ The following are some notes for this feature:
7. This feature is experimental for now and may not be very robust, so any issues are warmly welcomed if unexpected results are encountered.
## Change batch size
`onnx-modifier` now supports editing the batch size. Both `Dynamic batch size` and `Fixed batch size` modes are supported.
- `Dynamic batch size`: Click the `Dynamic batch size` button, and we get a model which supports dynamic batch size inference;
- `Fixed batch size`: Enter the desired fixed batch size in the input box, and we are done;
<img src="./docs/rebatch.gif" style="zoom:75%;" />
Note the differences between `fixed batch size inference` and `dynamic batch size inference`, as [this blog](https://nietras.com/2021/05/24/set-dynamic-batch-size-using-onnx-sharp/) illustrates:
> - When running a model with only fixed dimensions, the ONNX Runtime will prepare and optimize the graph for execution when constructing the Inference Session.
> - When the model has dynamic dimensions like batch size, the ONNX Runtime may instead cache optimized graphs for specific batch sizes when inputs are first encountered for that batch size.
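For a quick sanity check, a model saved with a dynamic batch dimension accepts inputs of different batch sizes under `onnxruntime`. A minimal sketch (the file name `modified_model.onnx` and the 3x224x224 input shape are placeholders):

```python
# Minimal sketch: run a rebatched model under onnxruntime.
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('modified_model.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name

# With a dynamic batch dimension, one session serves any batch size;
# a fixed-batch model only accepts the batch size it was rebatched to.
for batch in (1, 4):
    x = np.random.randn(batch, 3, 224, 224).astype(np.float32)
    out = sess.run(None, {input_name: x})[0]
    print(batch, '->', out.shape)
```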
# Sample models
For quick testing, some typical sample models are provided as follows. Most of them are from the [onnx model zoo](https://github.com/onnx/models).

@@ -24,6 +24,7 @@
- [x] Add model output nodes
- [x] Edit node attribute values
- [x] Add new nodes (beta)
- [x] Change the model batch size
`onnx-modifier` is developed based on the popular model visualizer [Netron](https://github.com/lutzroeder/netron) and the lightweight web framework [flask](https://github.com/pallets/flask). Hope it can make some contribution to the community~
@@ -155,7 +156,12 @@
6. In the current version, if a node's input/output is of list type (like `Concat`), at most 8 items are displayed. If a node actually has fewer than 8 inputs/outputs, just fill in the corresponding number; the extra ones, which start with `list_custom`, are automatically ignored in later processing.
7. This feature is still under development and may not be robust enough, so issues are very welcome if any problems are encountered in practice!
## Change the model batch size
Both dynamic and fixed batch size are supported.
- Dynamic batch size: just click the `Dynamic batch size` button;
- Fixed batch size: fill in the expected batch size value in the input box after `Fixed batch size`;
<img src="./docs/rebatch.gif" style="zoom:75%;" />
`onnx-modifier` is under active development :hammer_and_wrench:. Welcome to use it and raise issues; if it helps, a :star: would be appreciated~

@@ -231,7 +231,8 @@ host.BrowserHost = class {
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
this._view._graph._renameMap, this._view._graph._modelNodeName2State))
this._view._graph._renameMap, this._view._graph._modelNodeName2State)),
'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo)
})
}).then(function (response) {
return response.text();

@@ -25,6 +25,8 @@ grapher.Graph = class {
this._addedNode = new Map();
this._addedOutputs = [];
this._reBatchInfo = new Map();
}
get options() {

@@ -1146,6 +1146,10 @@ sidebar.ModelSidebar = class {
const separator = this._host.document.createElement('div');
separator.className = 'sidebar-view-separator';
this._elements.push(separator);
this._addHeader('Batch size changing helper');
this._addRebatcher();
}
render() {
@@ -1164,6 +1168,38 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render());
}
_addRebatcher() {
this._addButton("Dynamic batch size");
var fixed_batch_size_title = this._host.document.createElement('span');
fixed_batch_size_title.innerHTML = "&nbsp;&nbsp;&nbsp;<strong> or </strong>&nbsp;&nbsp;Fixed batch size&nbsp;&nbsp;&nbsp;";
fixed_batch_size_title.setAttribute('style','font-size:14px');
this._elements.push(fixed_batch_size_title);
var fixed_batch_size_value = this._host.document.createElement('input');
fixed_batch_size_value.setAttribute('type', 'text');
fixed_batch_size_value.setAttribute('size', '5');
fixed_batch_size_value.setAttribute('value', '1');
fixed_batch_size_value.addEventListener('input', (e) => {
this._host._view._graph.changeBatchSize('fixed', e.target.value);
});
this._elements.push(fixed_batch_size_value);
}
_addButton(title) {
const buttonElement = this._host.document.createElement('button');
buttonElement.className = 'sidebar-view-button';
buttonElement.innerText = title;
this._elements.push(buttonElement);
if (title === 'Dynamic batch size') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.changeBatchSize('dynamic');
});
}
}
addArgument(name, argument, index, arg_type) {
// const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, arg_type, index, name);

@@ -1254,6 +1254,17 @@ view.Graph = class extends grapher.Graph {
this.view._updateGraph();
}
changeBatchSize(type, value) {
if (type === "fixed") {
this._reBatchInfo.set("type", "fixed");
this._reBatchInfo.set("value", value);
}
else { // dynamic
this._reBatchInfo.set("type", "dynamic");
this._reBatchInfo.set("value", "dynamic");
}
}
resetGraph() {
// reset node states
for (const nodeId of this.nodes.keys()) {
