ZhangGe6 4 years ago
parent d289f00f96
commit 5aabb9f0ce

@ -10,7 +10,8 @@ import struct
import numpy as np import numpy as np
import onnx import onnx
from onnx import numpy_helper from onnx import numpy_helper
from utils import make_new_node, make_attr_changed_node, parse_tensor from utils import make_new_node, make_attr_changed_node
from utils import parse_tensor, np2onnxdtype
class onnxModifier: class onnxModifier:
def __init__(self, model_name, model_proto): def __init__(self, model_name, model_proto):
@ -154,7 +155,7 @@ class onnxModifier:
if node.output[i] == src_name: if node.output[i] == src_name:
node.output[i] = dst_name node.output[i] = dst_name
# TODO: rename the corresponding initializer and update initializer_name2module # rename the corresponding initializer and update initializer_name2module
if src_name in self.initializer_name2module.keys(): if src_name in self.initializer_name2module.keys():
init = self.initializer_name2module[src_name] init = self.initializer_name2module[src_name]
init.name = dst_name init.name = dst_name
@ -170,8 +171,9 @@ class onnxModifier:
self.graph.node.remove(self.node_name2module[node_name]) self.graph.node.remove(self.node_name2module[node_name])
self.graph.node.append(attr_changed_node) self.graph.node.append(attr_changed_node)
# update the node_name2module and initializer_name2module # update the node_name2module
self.gen_name2module_map() del self.node_name2module[node_name]
self.node_name2module[node_name] = attr_changed_node
def add_nodes(self, nodes_info, node_states): def add_nodes(self, nodes_info, node_states):
for node_info in nodes_info.values(): for node_info in nodes_info.values():
@ -195,19 +197,36 @@ class onnxModifier:
self.graph.output.extend(value_info_protos) self.graph.output.extend(value_info_protos)
def modify_initializer(self, changed_initializer): def modify_initializer(self, changed_initializer):
# print(changed_initializer)
for init_name, meta in changed_initializer.items(): for init_name, meta in changed_initializer.items():
# https://github.com/onnx/onnx/issues/2978 # https://github.com/onnx/onnx/issues/2978
init_type, init_val_str = meta init_type, init_val_str = meta
if init_val_str == "": continue # in case we clear the input
# print(init_name, init_type, init_val) # print(init_name, init_type, init_val)
init_val = parse_tensor(init_val_str, init_type) init_val = parse_tensor(init_val_str, init_type)
# print(init_val) # print(init_val)
tensor = numpy_helper.from_array(init_val, init_name) # for primary initilizers
self.initializer_name2module[init_name].CopyFrom(tensor) if init_name in self.initializer_name2module.keys():
tensor = numpy_helper.from_array(init_val, init_name)
self.initializer_name2module[init_name].CopyFrom(tensor)
# for custom added initilizers
else:
initializer_tensor = onnx.helper.make_tensor(
name=init_name,
data_type=np2onnxdtype(init_val.dtype),
dims=init_val.shape,
vals=init_val)
# print(initializer_tensor)
self.initializer.append(initializer_tensor)
self.initializer_name2module[init_name] = initializer_tensor
def modify(self, modify_info): def modify(self, modify_info):
''' '''
Some functions, such as modify_initializer(), should be placed 1. Some functions, such as modify_initializer(), should be placed
before modify_node_io_name(), to avoid name mismatch error. before modify_node_io_name(), to avoid name mismatch error.
2. add_nodes() should be placed at the first place, otherwise
remove_node_by_node_states() will delete the initializer of
newly added nodes mistakenly
''' '''
# print(modify_info['node_states']) # print(modify_info['node_states'])
# print(modify_info['node_renamed_io']) # print(modify_info['node_renamed_io'])
@ -215,12 +234,12 @@ class onnxModifier:
# print(modify_info['added_node_info']) # print(modify_info['added_node_info'])
# print(modify_info['added_outputs']) # print(modify_info['added_outputs'])
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
self.modify_initializer(modify_info['changed_initializer']) self.modify_initializer(modify_info['changed_initializer'])
self.change_batch_size(modify_info['rebatch_info']) self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states']) self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr']) self.modify_node_attr(modify_info['node_changed_attr'])
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs']) self.add_outputs(modify_info['added_outputs'])
def check_and_save_model(self, save_dir='./modified_onnx'): def check_and_save_model(self, save_dir='./modified_onnx'):
@ -257,7 +276,7 @@ class onnxModifier:
if __name__ == "__main__": if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic(): def explore_basic():
@ -368,10 +387,21 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
# test_change_batch_size() # test_change_batch_size()
def test_modify_initializer(): def test_modify_primary_initializer():
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']}) onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']})
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368']) onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
test_modify_initializer() # test_modify_primary_initializer()
def test_modify_new_initializer():
modify_info = {'node_states': {'input': 'Exist', 'Conv_0': 'Exist', 'LeakyRelu_1': 'Exist', 'Conv_2': 'Exist', 'LeakyRelu_3': 'Exist', 'Conv_4': 'Exist', 'LeakyRelu_5': 'Exist', 'Conv_6': 'Exist', 'LeakyRelu_7': 'Exist', 'Conv_8': 'Exist', 'LeakyRelu_9': 'Exist', 'Conv_10': 'Exist', 'Conv_11': 'Exist', 'LeakyRelu_12': 'Exist', 'Conv_13': 'Exist', 'Conv_14': 'Exist', 'LeakyRelu_15': 'Exist', 'Conv_16': 'Exist', 'Concat_17': 'Exist', 'LeakyRelu_18': 'Exist', 'Conv_19': 'Exist', 'Sigmoid_20': 'Exist', 'Mul_22': 'Exist', 'Conv_23': 'Exist', 'LeakyRelu_24': 'Exist', 'Conv_25': 'Exist', 'Conv_26': 'Exist', 'LeakyRelu_27': 'Exist', 'Conv_28': 'Exist', 'Add_29': 'Exist', 'Conv_30': 'Exist', 'Conv_31': 'Exist', 'LeakyRelu_32': 'Exist', 'Conv_33': 'Exist', 'Conv_34': 'Exist', 'LeakyRelu_35': 'Exist', 'Conv_36': 'Exist', 'Concat_37': 'Exist', 'LeakyRelu_38': 'Exist', 'Conv_39': 'Exist',
'Conv_40': 'Exist', 'LeakyRelu_41': 'Exist', 'Conv_42': 'Exist', 'LeakyRelu_43': 'Exist', 'Conv_44': 'Exist', 'Conv_45': 'Exist', 'LeakyRelu_46': 'Exist', 'Concat_47': 'Exist', 'Reshape_49': 'Exist', 'out_onnx::Transpose_368': 'Exist', 'custom_added_Reshape0': 'Exist', 'out_custom_output_2': 'Exist'}, 'node_renamed_io': {}, 'node_changed_attr': {}, 'added_node_info': {'custom_added_Reshape0': {'properties': {'domain': 'ai.onnx', 'op_type': 'Reshape', 'name': 'custom_added_Reshape0'}, 'attributes': {}, 'inputs': {'data': ['onnx::Transpose_368'], 'shape': ['custom_input_1']}, 'outputs': {'reshaped': ['custom_output_2']}}}, 'added_outputs': {'0': 'custom_output_2'}, 'rebatch_info': {}, 'changed_initializer': {'custom_input_1': ['int64', '[1, 2, 32, 24, 6]']}}
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['custom_output_2'])
print(onnx_modifier.initializer_name2module.keys())
for initializer in onnx_modifier.initializer:
print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}")
test_modify_new_initializer()

@ -1003,20 +1003,74 @@ sidebar.ArgumentView = class {
this._element.appendChild(location); this._element.appendChild(location);
} }
if (initializer || this._argument.is_custom_added) { if (initializer) {
const editInitializer = this._host.document.createElement('div'); const editInitializerVal = this._host.document.createElement('div');
editInitializer.className = 'sidebar-view-item-value-line-border'; editInitializerVal.className = 'sidebar-view-item-value-line-border';
editInitializer.innerHTML = 'If this is an initializer, you can input new value for it here:'; editInitializerVal.innerHTML = 'This is an initializer, you can input a new value for it here:';
this._element.appendChild(editInitializer); this._element.appendChild(editInitializerVal);
var inputInitializer = document.createElement("INPUT"); var inputInitializerVal = document.createElement("INPUT");
inputInitializer.setAttribute("type", "text"); inputInitializerVal.setAttribute("type", "text");
inputInitializer.setAttribute("size", "42"); inputInitializerVal.setAttribute("size", "42");
inputInitializer.addEventListener('input', (e) => { // reload the last value
var orig_arg_name = this._host._view._graph.getOriginalName(this._param_type, this._modelNodeName, this._param_index, this._arg_index)
if (this._host._view._graph._initializerEditInfo.get(orig_arg_name)) {
// [type, value]
inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(orig_arg_name)[1]);
}
inputInitializerVal.addEventListener('input', (e) => {
// console.log(e.target.value) // console.log(e.target.value)
this._host._view._graph.changeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, this._argument.type._dataType, e.target.value); this._host._view._graph.changeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, this._argument.type._dataType, e.target.value);
}); });
this._element.appendChild(inputInitializer); this._element.appendChild(inputInitializerVal);
}
if (this._argument.is_custom_added) {
var new_init_val = "", new_init_type = "";
// ====== input value ======>
const editInitializerVal = this._host.document.createElement('div');
editInitializerVal.className = 'sidebar-view-item-value-line-border';
editInitializerVal.innerHTML = 'If this is an initializer, you can input new value for it here:';
this._element.appendChild(editInitializerVal);
var inputInitializerVal = document.createElement("INPUT");
inputInitializerVal.setAttribute("type", "text");
inputInitializerVal.setAttribute("size", "42");
// this._element.appendChild(inputInitializerVal);
inputInitializerVal.addEventListener('input', (e) => {
// console.log(e.target.value)
new_init_val = e.target.value;
this._host._view._graph.changeAddedNodeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, new_init_type, new_init_val);
});
this._element.appendChild(inputInitializerVal);
// <====== input value ======
// ====== input type ======>
const editInitializerType = this._host.document.createElement('div');
editInitializerType.className = 'sidebar-view-item-value-line-border';
editInitializerType.innerHTML = 'and input its type for it here (see properties->type->? for more info):';
this._element.appendChild(editInitializerType);
var inputInitializerType = document.createElement("INPUT");
inputInitializerType.setAttribute("type", "text");
inputInitializerType.setAttribute("size", "42");
var arg_name = this._host._view._graph._addedNode.get(this._modelNodeName).inputs.get(this._parameterName)[this._arg_index]
if (this._host._view._graph._initializerEditInfo.get(arg_name)) {
// [type, value]
inputInitializerType.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[0]);
inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[1]);
}
inputInitializerType.addEventListener('input', (e) => {
// console.log(e.target.value)
new_init_type = e.target.value;
this._host._view._graph.changeAddedNodeInitializer(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, new_init_type, new_init_val);
});
this._element.appendChild(inputInitializerType);
// <====== input type ======
} }
if (initializer) { if (initializer) {

@ -585,6 +585,7 @@ view.View = class {
viewGraph._addedNode = this.lastViewGraph._addedNode; viewGraph._addedNode = this.lastViewGraph._addedNode;
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey; viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey;
viewGraph._addedOutputs = this.lastViewGraph._addedOutputs; viewGraph._addedOutputs = this.lastViewGraph._addedOutputs;
viewGraph._initializerEditInfo = this.lastViewGraph._initializerEditInfo;
// console.log(viewGraph._renameMap); // console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State) // console.log(viewGraph._modelNodeName2State)
@ -1033,12 +1034,6 @@ view.View = class {
} }
} }
} }
reloadLastLocation() {
const container = this._getElementById('graph');
}
}; };
view.Graph = class extends grapher.Graph { view.Graph = class extends grapher.Graph {
@ -1432,29 +1427,23 @@ view.Graph = class extends grapher.Graph {
} }
changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { // for custom added node
// if (this._addedNode.has(modelNodeName)) { // for custom added node }
// if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { // console.log(this._addedNode)
// this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue
// }
// if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) {
// this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue
// }
// // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph
// }
// // console.log(this._addedNode)
// else { // for the nodes in the original model else { // for the nodes in the original model
var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index)
this._initializerEditInfo.set(orig_arg_name, [type, targetValue]);
}
var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) this.view._updateGraph()
this._initializerEditInfo.set(orig_arg_name, [type, targetValue]); }
// console.log(this._renameMap)
// }
console.log(this._initializerEditInfo)
// this.view._updateGraph() changeAddedNodeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) {
} var arg_name = this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index]
this._initializerEditInfo.set(arg_name, [type, targetValue]);
this.view._updateGraph()
}
build(document, origin) { build(document, origin) {

@ -1,4 +1,6 @@
import numpy as np import numpy as np
from typing import cast
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
def parse_value(value_str, value_type): def parse_value(value_str, value_type):
if value_type.startswith('int'): if value_type.startswith('int'):
@ -9,6 +11,7 @@ def parse_value(value_str, value_type):
raise RuntimeError("type {} is not considered in current version. \ raise RuntimeError("type {} is not considered in current version. \
You can kindly report an issue for this problem. Thanks!".format(value_type)) You can kindly report an issue for this problem. Thanks!".format(value_type))
# parse numpy values from string
def parse_tensor(tensor_str, tensor_type): def parse_tensor(tensor_str, tensor_type):
def extract_val(): def extract_val():
num_str = "" num_str = ""
@ -22,7 +25,7 @@ def parse_tensor(tensor_str, tensor_type):
tensor_str = tensor_str.replace(" ", "") tensor_str = tensor_str.replace(" ", "")
stk = [] stk = []
for i, c in enumerate(tensor_str): # '[' ',' ']' '.' '-' or value for c in tensor_str: # '[' ',' ']' '.' '-' or value
if c == ",": if c == ",":
ext_val = extract_val() ext_val = extract_val()
if ext_val is not None: stk.append(ext_val) if ext_val is not None: stk.append(ext_val)
@ -56,6 +59,11 @@ def parse_tensor(tensor_str, tensor_type):
raise RuntimeError("type {} is not considered in current version. \ raise RuntimeError("type {} is not considered in current version. \
You can kindly report an issue for this problem. Thanks!".format(tensor_type)) You can kindly report an issue for this problem. Thanks!".format(tensor_type))
# map np datatype to onnx datatype
# https://github.com/onnx/onnx/blob/8669fad0247799f4d8683550eec749974b4f5338/onnx/helper.py#L1177
def np2onnxdtype(np_dtype):
return cast(int, NP_TYPE_TO_TENSOR_TYPE[np_dtype])
if __name__ == "__main__": if __name__ == "__main__":
# tensor_str = "1" # tensor_str = "1"
# tensor_str = "[1, 2, 3]" # tensor_str = "[1, 2, 3]"

Loading…
Cancel
Save