|
|
|
@ -10,7 +10,8 @@ import struct
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import onnx
|
|
|
|
import onnx
|
|
|
|
from onnx import numpy_helper
|
|
|
|
from onnx import numpy_helper
|
|
|
|
from utils import make_new_node, make_attr_changed_node, parse_tensor
|
|
|
|
from utils import make_new_node, make_attr_changed_node
|
|
|
|
|
|
|
|
from utils import parse_tensor, np2onnxdtype
|
|
|
|
|
|
|
|
|
|
|
|
class onnxModifier:
|
|
|
|
class onnxModifier:
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
@ -154,7 +155,7 @@ class onnxModifier:
|
|
|
|
if node.output[i] == src_name:
|
|
|
|
if node.output[i] == src_name:
|
|
|
|
node.output[i] = dst_name
|
|
|
|
node.output[i] = dst_name
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: rename the corresponding initializer and update initializer_name2module
|
|
|
|
# rename the corresponding initializer and update initializer_name2module
|
|
|
|
if src_name in self.initializer_name2module.keys():
|
|
|
|
if src_name in self.initializer_name2module.keys():
|
|
|
|
init = self.initializer_name2module[src_name]
|
|
|
|
init = self.initializer_name2module[src_name]
|
|
|
|
init.name = dst_name
|
|
|
|
init.name = dst_name
|
|
|
|
@ -170,8 +171,9 @@ class onnxModifier:
|
|
|
|
self.graph.node.remove(self.node_name2module[node_name])
|
|
|
|
self.graph.node.remove(self.node_name2module[node_name])
|
|
|
|
self.graph.node.append(attr_changed_node)
|
|
|
|
self.graph.node.append(attr_changed_node)
|
|
|
|
|
|
|
|
|
|
|
|
# update the node_name2module and initializer_name2module
|
|
|
|
# update the node_name2module
|
|
|
|
self.gen_name2module_map()
|
|
|
|
del self.node_name2module[node_name]
|
|
|
|
|
|
|
|
self.node_name2module[node_name] = attr_changed_node
|
|
|
|
|
|
|
|
|
|
|
|
def add_nodes(self, nodes_info, node_states):
|
|
|
|
def add_nodes(self, nodes_info, node_states):
|
|
|
|
for node_info in nodes_info.values():
|
|
|
|
for node_info in nodes_info.values():
|
|
|
|
@ -195,19 +197,36 @@ class onnxModifier:
|
|
|
|
self.graph.output.extend(value_info_protos)
|
|
|
|
self.graph.output.extend(value_info_protos)
|
|
|
|
|
|
|
|
|
|
|
|
def modify_initializer(self, changed_initializer):
|
|
|
|
def modify_initializer(self, changed_initializer):
|
|
|
|
|
|
|
|
# print(changed_initializer)
|
|
|
|
for init_name, meta in changed_initializer.items():
|
|
|
|
for init_name, meta in changed_initializer.items():
|
|
|
|
# https://github.com/onnx/onnx/issues/2978
|
|
|
|
# https://github.com/onnx/onnx/issues/2978
|
|
|
|
init_type, init_val_str = meta
|
|
|
|
init_type, init_val_str = meta
|
|
|
|
|
|
|
|
if init_val_str == "": continue # in case we clear the input
|
|
|
|
# print(init_name, init_type, init_val)
|
|
|
|
# print(init_name, init_type, init_val)
|
|
|
|
init_val = parse_tensor(init_val_str, init_type)
|
|
|
|
init_val = parse_tensor(init_val_str, init_type)
|
|
|
|
# print(init_val)
|
|
|
|
# print(init_val)
|
|
|
|
tensor = numpy_helper.from_array(init_val, init_name)
|
|
|
|
# for primary initilizers
|
|
|
|
self.initializer_name2module[init_name].CopyFrom(tensor)
|
|
|
|
if init_name in self.initializer_name2module.keys():
|
|
|
|
|
|
|
|
tensor = numpy_helper.from_array(init_val, init_name)
|
|
|
|
|
|
|
|
self.initializer_name2module[init_name].CopyFrom(tensor)
|
|
|
|
|
|
|
|
# for custom added initilizers
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
initializer_tensor = onnx.helper.make_tensor(
|
|
|
|
|
|
|
|
name=init_name,
|
|
|
|
|
|
|
|
data_type=np2onnxdtype(init_val.dtype),
|
|
|
|
|
|
|
|
dims=init_val.shape,
|
|
|
|
|
|
|
|
vals=init_val)
|
|
|
|
|
|
|
|
# print(initializer_tensor)
|
|
|
|
|
|
|
|
self.initializer.append(initializer_tensor)
|
|
|
|
|
|
|
|
self.initializer_name2module[init_name] = initializer_tensor
|
|
|
|
|
|
|
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
def modify(self, modify_info):
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
Some functions, such as modify_initializer(), should be placed
|
|
|
|
1. Some functions, such as modify_initializer(), should be placed
|
|
|
|
before modify_node_io_name(), to avoid name mismatch error.
|
|
|
|
before modify_node_io_name(), to avoid name mismatch error.
|
|
|
|
|
|
|
|
2. add_nodes() should be placed at the first place, otherwise
|
|
|
|
|
|
|
|
remove_node_by_node_states() will delete the initializer of
|
|
|
|
|
|
|
|
newly added nodes mistakenly
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
@ -215,14 +234,14 @@ class onnxModifier:
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
# print(modify_info['added_outputs'])
|
|
|
|
# print(modify_info['added_outputs'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
self.modify_initializer(modify_info['changed_initializer'])
|
|
|
|
self.modify_initializer(modify_info['changed_initializer'])
|
|
|
|
self.change_batch_size(modify_info['rebatch_info'])
|
|
|
|
self.change_batch_size(modify_info['rebatch_info'])
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
|
|
|
self.add_outputs(modify_info['added_outputs'])
|
|
|
|
self.add_outputs(modify_info['added_outputs'])
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
print("saving model...")
|
|
|
|
print("saving model...")
|
|
|
|
if not os.path.exists(save_dir):
|
|
|
|
if not os.path.exists(save_dir):
|
|
|
|
@ -257,7 +276,7 @@ class onnxModifier:
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx"
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
def explore_basic():
|
|
|
|
@ -368,10 +387,21 @@ if __name__ == "__main__":
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
# test_change_batch_size()
|
|
|
|
# test_change_batch_size()
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_initializer():
|
|
|
|
def test_modify_primary_initializer():
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']})
|
|
|
|
onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']})
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
test_modify_initializer()
|
|
|
|
# test_modify_primary_initializer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_new_initializer():
|
|
|
|
|
|
|
|
modify_info = {'node_states': {'input': 'Exist', 'Conv_0': 'Exist', 'LeakyRelu_1': 'Exist', 'Conv_2': 'Exist', 'LeakyRelu_3': 'Exist', 'Conv_4': 'Exist', 'LeakyRelu_5': 'Exist', 'Conv_6': 'Exist', 'LeakyRelu_7': 'Exist', 'Conv_8': 'Exist', 'LeakyRelu_9': 'Exist', 'Conv_10': 'Exist', 'Conv_11': 'Exist', 'LeakyRelu_12': 'Exist', 'Conv_13': 'Exist', 'Conv_14': 'Exist', 'LeakyRelu_15': 'Exist', 'Conv_16': 'Exist', 'Concat_17': 'Exist', 'LeakyRelu_18': 'Exist', 'Conv_19': 'Exist', 'Sigmoid_20': 'Exist', 'Mul_22': 'Exist', 'Conv_23': 'Exist', 'LeakyRelu_24': 'Exist', 'Conv_25': 'Exist', 'Conv_26': 'Exist', 'LeakyRelu_27': 'Exist', 'Conv_28': 'Exist', 'Add_29': 'Exist', 'Conv_30': 'Exist', 'Conv_31': 'Exist', 'LeakyRelu_32': 'Exist', 'Conv_33': 'Exist', 'Conv_34': 'Exist', 'LeakyRelu_35': 'Exist', 'Conv_36': 'Exist', 'Concat_37': 'Exist', 'LeakyRelu_38': 'Exist', 'Conv_39': 'Exist',
|
|
|
|
|
|
|
|
'Conv_40': 'Exist', 'LeakyRelu_41': 'Exist', 'Conv_42': 'Exist', 'LeakyRelu_43': 'Exist', 'Conv_44': 'Exist', 'Conv_45': 'Exist', 'LeakyRelu_46': 'Exist', 'Concat_47': 'Exist', 'Reshape_49': 'Exist', 'out_onnx::Transpose_368': 'Exist', 'custom_added_Reshape0': 'Exist', 'out_custom_output_2': 'Exist'}, 'node_renamed_io': {}, 'node_changed_attr': {}, 'added_node_info': {'custom_added_Reshape0': {'properties': {'domain': 'ai.onnx', 'op_type': 'Reshape', 'name': 'custom_added_Reshape0'}, 'attributes': {}, 'inputs': {'data': ['onnx::Transpose_368'], 'shape': ['custom_input_1']}, 'outputs': {'reshaped': ['custom_output_2']}}}, 'added_outputs': {'0': 'custom_output_2'}, 'rebatch_info': {}, 'changed_initializer': {'custom_input_1': ['int64', '[1, 2, 32, 24, 6]']}}
|
|
|
|
|
|
|
|
onnx_modifier.modify(modify_info)
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['custom_output_2'])
|
|
|
|
|
|
|
|
print(onnx_modifier.initializer_name2module.keys())
|
|
|
|
|
|
|
|
for initializer in onnx_modifier.initializer:
|
|
|
|
|
|
|
|
print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_modify_new_initializer()
|
|
|
|
|
|
|
|
|