ZhangGe6 3 years ago
parent 3dc195bf53
commit d289f00f96

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
# print(modify_info)
print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()

@ -9,7 +9,8 @@ import copy
import struct
import numpy as np
import onnx
from utils import make_new_node, make_attr_changed_node
from onnx import numpy_helper
from utils import make_new_node, make_attr_changed_node, parse_tensor
class onnxModifier:
def __init__(self, model_name, model_proto):
@ -100,7 +101,7 @@ class onnxModifier:
# remove node in graph
self.graph.node.remove(self.node_name2module[node_name])
def remove_output_by_name(self, node_name):
def remove_model_output_by_name(self, node_name):
self.graph.output.remove(self.node_name2module[node_name])
def remove_node_by_node_states(self, node_states):
@ -112,7 +113,7 @@ class onnxModifier:
if node_state == 'Deleted':
if node_name in self.graph_output_names:
# print('removing output {} ...'.format(node_name))
self.remove_output_by_name(node_name)
self.remove_model_output_by_name(node_name)
else:
# print('removing node {} ...'.format(node_name))
self.remove_node_by_name(node_name)
@ -131,24 +132,11 @@ class onnxModifier:
for input_name in self.graph_input_names:
if not input_name in left_node_inputs:
self.graph.input.remove(self.node_name2module[input_name])
# remove the left unused Constant nodes
for left_node in self.graph.node:
if left_node.op_type == "Constant":
output_deleted = [False] * len(left_node.output)
for i, output in enumerate(left_node.output):
if not (output in left_node_inputs):
output_deleted[i] = True
const_node_left_output = [left_node.output[i] for i in range(len(left_node.output)) if not output_deleted[i]]
if len(const_node_left_output) == 0:
self.graph.node.remove(self.node_name2module[left_node.name])
# self.initializer.remove(self.initializer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io):
for node_name in node_renamed_io.keys():
if node_name not in self.node_name2module.keys():
# custom added nodes or custom added model outputs
# custom added nodes or custom added model outputs, or the deleted nodes
continue
renamed_ios = node_renamed_io[node_name]
for src_name, dst_name in renamed_ios.items():
@ -164,7 +152,14 @@ class onnxModifier:
node.input[i] = dst_name
for i in range(len(node.output)):
if node.output[i] == src_name:
node.output[i] = dst_name
node.output[i] = dst_name
# TODO: rename the corresponding initializer and update initializer_name2module
if src_name in self.initializer_name2module.keys():
init = self.initializer_name2module[src_name]
init.name = dst_name
self.initializer_name2module[dst_name] = init
del self.initializer_name2module[src_name]
def modify_node_attr(self, node_changed_attr):
# we achieve it by deleting the original node and make a (copied) new node
@ -178,7 +173,7 @@ class onnxModifier:
# update the node_name2module and initializer_name2module
self.gen_name2module_map()
def add_node(self, nodes_info, node_states):
def add_nodes(self, nodes_info, node_states):
for node_info in nodes_info.values():
if node_states[node_info['properties']['name']] == "Deleted":
continue
@ -199,17 +194,33 @@ class onnxModifier:
value_info_protos.append(value_info)
self.graph.output.extend(value_info_protos)
def modify_initializer(self, changed_initializer):
    """Overwrite existing initializer tensors with user-edited values.

    `changed_initializer` maps an initializer name to a ``(dtype_str,
    value_str)`` pair; the value string is parsed by ``utils.parse_tensor``
    and copied into the existing TensorProto in place, so references to the
    proto object elsewhere in the graph stay valid.

    NOTE(review): assumes every key is present in
    ``self.initializer_name2module`` — a stale name raises KeyError; confirm
    the frontend only sends names of real initializers.
    """
    for init_name, meta in changed_initializer.items():
        # Rebuilding the tensor with numpy_helper and CopyFrom-ing avoids
        # hand-editing raw_data bytes; see
        # https://github.com/onnx/onnx/issues/2978
        init_type, init_val_str = meta
        # print(init_name, init_type, init_val)
        init_val = parse_tensor(init_val_str, init_type)
        # print(init_val)
        tensor = numpy_helper.from_array(init_val, init_name)
        self.initializer_name2module[init_name].CopyFrom(tensor)
def modify(self, modify_info):
'''
Some functions, such as modify_initializer(), should be placed
before modify_node_io_name(), to avoid name mismatch error.
'''
# print(modify_info['node_states'])
# print(modify_info['node_renamed_io'])
# print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
# print(modify_info['added_outputs'])
self.modify_initializer(modify_info['changed_initializer'])
self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs'])
def check_and_save_model(self, save_dir='./modified_onnx'):
@ -218,7 +229,7 @@ class onnxModifier:
os.mkdir(save_dir)
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
# adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
# adding new node like self.add_nodes() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
# so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked
# I turn off the onnx checker as a workaround.
# onnx.checker.check_model(self.model_proto)
@ -226,7 +237,10 @@ class onnxModifier:
print("model saved in {} !".format(save_dir))
def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
import onnxruntime as rt
import onnxruntime as rt
model_proto_bytes = onnx._serialize(self.model_proto)
inference_session = rt.InferenceSession(model_proto_bytes)
if not x:
x = np.random.randn(*input_shape).astype(np.float32)
if not output_names:
@ -234,20 +248,16 @@ class onnxModifier:
# output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[])
output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[])
self.graph.output.append(output_value_info)
model_proto_bytes = onnx._serialize(self.model_proto)
inference_session = rt.InferenceSession(model_proto_bytes)
output_names = [inference_session.get_outputs()[0].name]
input_name = inference_session.get_inputs()[0].name
output_name = inference_session.get_outputs()[0].name
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
out = inference_session.run(None, {input_name: x})[0]
out = inference_session.run(output_names, {input_name: x})[0]
print(out.shape)
if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():
@ -315,7 +325,7 @@ if __name__ == "__main__":
def test_add_node():
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
onnx_modifier.add_node(node_info)
onnx_modifier.add_nodes(node_info)
onnx_modifier.inference()
onnx_modifier.check_and_save_model()
@ -357,4 +367,11 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model()
# test_change_batch_size()
def test_modify_initializer():
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']})
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
test_modify_initializer()

@ -232,7 +232,8 @@ host.BrowserHost = class {
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
this._view._graph._renameMap, this._view._graph._modelNodeName2State)),
'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo)
'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo),
'changed_initializer' : this.mapToObjectRec(this._view._graph._initializerEditInfo)
})
}).then(function (response) {
return response.text();

@ -444,6 +444,7 @@ onnx.Graph = class {
// console.log(graph)
for (const initializer of graph.initializer) {
const tensor = context.tensor(initializer.name);
// console.log(initializer) // type: TensorProto
tensor.initializer = new onnx.Tensor(context, initializer, 'Initializer');
}
for (const sparse_initializer of graph.sparse_initializer) {
@ -566,6 +567,9 @@ onnx.Graph = class {
arg_list = [this._context.argument(arg_name)]
}
for (var arg of arg_list) {
arg.is_custom_added = true;
}
inputs.push(new onnx.Parameter(input.name, arg_list));
}
@ -596,6 +600,10 @@ onnx.Graph = class {
arg_list = [this._context.argument(arg_name)]
}
for (var arg of arg_list) {
arg.is_custom_added = true;
}
outputs.push(new onnx.Parameter(output.name, arg_list));
}
@ -677,7 +685,8 @@ onnx.Argument = class {
this._annotation = annotation;
this._description = description || '';
this.original_name = original_name || name
this.original_name = original_name || name;
this.is_custom_added = false;
}
@ -1766,6 +1775,7 @@ onnx.GraphContext = class {
argument(name, original_name) {
const tensor = this.tensor(name);
// console.log(tensor)
const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
return new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description, original_name);

@ -27,6 +27,7 @@ grapher.Graph = class {
this._addedOutputs = [];
this._reBatchInfo = new Map();
this._initializerEditInfo = new Map();
}
get options() {

@ -24,6 +24,7 @@
.sidebar-view-item-value-line-link { padding: 4px 6px 4px 6px; cursor: default; }
.sidebar-view-item-value-line-link:hover { text-decoration: underline; }
.sidebar-view-item-value-line-border { padding: 4px 6px 4px 6px; border-top: 1px solid rgba(27, 31, 35, 0.05); }
.sidebar-view-item-value-border { padding: 4px 6px 4px 6px;}
.sidebar-view-item-value-line-content { white-space: pre; word-wrap: normal; overflow: auto; display: block; }
.sidebar-view-item-value-expander { font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace; float: right; color: #aaa; cursor: pointer; user-select: none; -webkit-user-select: none; -moz-user-select: none; padding: 4px 6px 4px 6px; }
.sidebar-view-item-value-expander:hover { color: #000; }

@ -132,6 +132,7 @@ sidebar.NodeSidebar = class {
this._attributes = [];
this._inputs = [];
this._outputs = [];
// console.log(node) // onnx.Node
if (node.type) {
let showDocumentation = null;
@ -305,7 +306,7 @@ sidebar.NodeSidebar = class {
}
_addInput(name, input, param_idx) {
// console.log(input)
// console.log(input) // type: onnx.Parameter
if (input.arguments.length > 0) {
const view = new sidebar.ParameterView(this._host, input, 'input', param_idx, this._modelNodeName);
view.on('export-tensor', (sender, tensor) => {
@ -875,7 +876,9 @@ sidebar.ArgumentView = class {
const quantization = argument.quantization;
const type = argument.type;
const location = this._argument.location !== undefined;
if (type || initializer || quantization || location) {
const is_custom_added = argument.is_custom_added;
// console.log(argument)
if (type || initializer || quantization || location || is_custom_added) {
this._expander = this._host.document.createElement('div');
this._expander.className = 'sidebar-view-item-value-expander';
this._expander.innerText = '+';
@ -949,6 +952,7 @@ sidebar.ArgumentView = class {
this._expander.innerText = '-';
const initializer = this._argument.initializer;
// console.log(this._argument, initializer) // type: onnx.Argument, onnx.Tensor
if (this._hasId && this._hasKind) {
const kindLine = this._host.document.createElement('div');
kindLine.className = 'sidebar-view-item-value-line-border';
@ -998,8 +1002,29 @@ sidebar.ArgumentView = class {
location.innerHTML = 'location: ' + '<b>' + this._argument.location + '</b>';
this._element.appendChild(location);
}
if (initializer || this._argument.is_custom_added) {
const editInitializer = this._host.document.createElement('div');
editInitializer.className = 'sidebar-view-item-value-line-border';
editInitializer.innerHTML = 'If this is an initializer, you can input new value for it here:';
this._element.appendChild(editInitializer);
var inputInitializer = document.createElement("INPUT");
inputInitializer.setAttribute("type", "text");
inputInitializer.setAttribute("size", "42");
inputInitializer.addEventListener('input', (e) => {
// console.log(e.target.value)
this._host._view._graph.changeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, this._argument.type._dataType, e.target.value);
});
this._element.appendChild(inputInitializer);
}
if (initializer) {
// to edit the existed initializer
const origInitLine = this._host.document.createElement('div');
origInitLine.className = 'sidebar-view-item-value-line-border';
origInitLine.innerHTML = 'original initializer value:';
this._element.appendChild(origInitLine);
const contentLine = this._host.document.createElement('pre');
const valueLine = this._host.document.createElement('div');
try {
@ -1016,10 +1041,11 @@ sidebar.ArgumentView = class {
this._element.appendChild(this._saveButton);
}
valueLine.className = 'sidebar-view-item-value-line-border';
// valueLine.className = 'sidebar-view-item-value-line-border';
valueLine.className = 'sidebar-view-item-value-border'
contentLine.innerHTML = state || initializer.toString();
console.log(initializer)
console.log(state, initializer.toString())
// console.log(initializer)
// console.log(state, initializer.toString())
}
catch (err) {
contentLine.innerHTML = err.toString();

@ -1121,6 +1121,7 @@ view.Graph = class extends grapher.Graph {
}
add(graph) {
// console.log(graph) // type: onnx.Graph
const clusters = new Set();
const clusterParentMap = new Map();
const groups = graph.groups;
@ -1145,6 +1146,7 @@ view.Graph = class extends grapher.Graph {
}
for (var node of graph.nodes) {
// console.log(node) // type: onnx.Node
var viewNode = this.createNode(node);
var inputs = node.inputs;
@ -1377,8 +1379,31 @@ view.Graph = class extends grapher.Graph {
this.view._updateGraph()
}
getOriginalName(param_type, modelNodeName, param_index, arg_index) {
if (param_type == 'model_input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
}
if (param_type == 'model_output') {
modelNodeName = 'out_' + modelNodeName
// console.log(modelNodeName)
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
// console.log(orig_arg_name)
}
if (param_type == 'input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (param_type == 'output') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue, orig_arg_name) {
return orig_arg_name
}
changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
@ -1393,26 +1418,8 @@ view.Graph = class extends grapher.Graph {
// console.log(this._addedNode)
else { // for the nodes in the original model
if (param_type == 'model_input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
}
if (param_type == 'model_output') {
modelNodeName = 'out_' + modelNodeName
// console.log(modelNodeName)
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].original_name
// console.log(orig_arg_name)
}
if (param_type == 'input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (param_type == 'output') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index)
console.log(orig_arg_name)
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
@ -1422,9 +1429,32 @@ view.Graph = class extends grapher.Graph {
}
this.view._updateGraph()
}
// Record a user edit to an initializer tensor's value, keyed by the
// argument's original (pre-rename) name. The accumulated map
// (this._initializerEditInfo) is serialized as 'changed_initializer' and
// sent to the Python backend, which parses the value string and rewrites
// the corresponding TensorProto.
// `type` is the tensor's dtype string; `targetValue` is the raw text the
// user typed into the sidebar input.
// NOTE(review): the commented-out branch below mirrors
// changeNodeInputOutput's handling of custom added nodes — presumably
// initializer editing for added nodes is not supported yet; confirm before
// removing it.
changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) {
    // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
    // if (this._addedNode.has(modelNodeName)) { // for custom added node
    //     if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
    //         this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue
    //     }
    //     if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) {
    //         this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue
    //     }
    //     // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph
    // }
    // // console.log(this._addedNode)
    // else { // for the nodes in the original model
    var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index)
    this._initializerEditInfo.set(orig_arg_name, [type, targetValue]);
    // console.log(this._renameMap)
    // }
    console.log(this._initializerEditInfo)
    // this.view._updateGraph()
}
build(document, origin) {

@ -1 +1,2 @@
from .make_nodes import *
from .make_nodes import *
from .parse_tools import *

@ -0,0 +1,70 @@
import numpy as np
def parse_value(value_str, value_type):
    """Convert a single scalar string into a Python number.

    `value_type` is a numpy-style dtype name ("int64", "float32", ...);
    only the "int"/"float" prefix matters here. Raises RuntimeError for
    any other type family.
    """
    for prefix, caster in (("int", int), ("float", float)):
        if value_type.startswith(prefix):
            return caster(value_str)
    raise RuntimeError("type {} is not considered in current version. \
You can kindly report an issue for this problem. Thanks!".format(value_type))
def parse_tensor(tensor_str, tensor_type):
    """Parse a bracketed literal such as "[[1, 2.3], [3, 2e6]]" into a numpy
    array of the requested dtype.

    Supported dtypes: int64 / int32 / int8 / float64 / float32; any other
    `tensor_type` raises RuntimeError. Scalars ("1", "2e6") and arbitrarily
    nested lists are both accepted; whitespace is ignored.
    """
    # Validate the dtype up front so a bad type fails before any parsing work.
    dtype_map = {
        "int64": np.int64,
        "int32": np.int32,
        "int8": np.int8,
        "float64": np.float64,
        "float32": np.float32,
    }
    if tensor_type not in dtype_map:
        raise RuntimeError("type {} is not considered in current version. \
You can kindly report an issue for this problem. Thanks!".format(tensor_type))

    def to_scalar(num_str):
        # int families use int(); float families use float() (handles "2e6").
        return int(num_str) if tensor_type.startswith("int") else float(num_str)

    def extract_val():
        # Pop the contiguous run of numeric characters off the top of the
        # stack and convert it; return None when no number is pending.
        num_str = ""
        while stk and isinstance(stk[-1], str) and (
                stk[-1].isdigit() or stk[-1] in ('+', '-', '.', 'e', 'E')):
            num_str = stk.pop() + num_str
        return to_scalar(num_str) if num_str else None

    stk = []
    for c in tensor_str.replace(" ", ""):  # chars: '[' ']' ',' or number parts
        if c in (",", "]"):
            ext_val = extract_val()
            if ext_val is not None:
                stk.append(ext_val)
            if c == "]":
                # Close the innermost list: gather values down to the
                # matching '[' (reversed, since we pop from the top).
                arr = []
                while stk[-1] != '[':
                    arr.append(stk.pop())
                stk.pop()  # discard the matching '['
                arr.reverse()
                stk.append(arr)
        else:
            stk.append(c)
    # Bug fix: a bare scalar string never passes through ',' or ']', so the
    # original code left the raw characters unconverted on the stack and
    # handed a *string* to np.array; flush any trailing number here.
    ext_val = extract_val()
    if ext_val is not None:
        stk.append(ext_val)
    return np.array(stk[0], dtype=dtype_map[tensor_type])
if __name__ == "__main__":
    # Ad-hoc smoke tests: run this module directly to eyeball the parsed
    # arrays. The commented-out strings exercise the scalar and flat-list
    # cases.
    # tensor_str = "1"
    # tensor_str = "[1, 2, 3]"
    tensor_str = "[[10, 2.3, 3],[1, 2e6, 3]]"
    val = parse_tensor(tensor_str, "float32")
    print(type(val), val)
    tensor_str = "[[10, 2, 3],[1, 2, 3]]"
    val = parse_tensor(tensor_str, "int64")
    print(type(val), val)
Loading…
Cancel
Save