|
|
|
|
@ -9,7 +9,8 @@ import copy
|
|
|
|
|
import struct
|
|
|
|
|
import numpy as np
|
|
|
|
|
import onnx
|
|
|
|
|
from utils import make_new_node, make_attr_changed_node
|
|
|
|
|
from onnx import numpy_helper
|
|
|
|
|
from utils import make_new_node, make_attr_changed_node, parse_tensor
|
|
|
|
|
|
|
|
|
|
class onnxModifier:
|
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
|
@ -100,7 +101,7 @@ class onnxModifier:
|
|
|
|
|
# remove node in graph
|
|
|
|
|
self.graph.node.remove(self.node_name2module[node_name])
|
|
|
|
|
|
|
|
|
|
def remove_output_by_name(self, node_name):
|
|
|
|
|
def remove_model_output_by_name(self, node_name):
|
|
|
|
|
self.graph.output.remove(self.node_name2module[node_name])
|
|
|
|
|
|
|
|
|
|
def remove_node_by_node_states(self, node_states):
|
|
|
|
|
@ -112,7 +113,7 @@ class onnxModifier:
|
|
|
|
|
if node_state == 'Deleted':
|
|
|
|
|
if node_name in self.graph_output_names:
|
|
|
|
|
# print('removing output {} ...'.format(node_name))
|
|
|
|
|
self.remove_output_by_name(node_name)
|
|
|
|
|
self.remove_model_output_by_name(node_name)
|
|
|
|
|
else:
|
|
|
|
|
# print('removing node {} ...'.format(node_name))
|
|
|
|
|
self.remove_node_by_name(node_name)
|
|
|
|
|
@ -131,24 +132,11 @@ class onnxModifier:
|
|
|
|
|
for input_name in self.graph_input_names:
|
|
|
|
|
if not input_name in left_node_inputs:
|
|
|
|
|
self.graph.input.remove(self.node_name2module[input_name])
|
|
|
|
|
|
|
|
|
|
# remove the left unused Constant nodes
|
|
|
|
|
for left_node in self.graph.node:
|
|
|
|
|
if left_node.op_type == "Constant":
|
|
|
|
|
output_deleted = [False] * len(left_node.output)
|
|
|
|
|
for i, output in enumerate(left_node.output):
|
|
|
|
|
if not (output in left_node_inputs):
|
|
|
|
|
output_deleted[i] = True
|
|
|
|
|
|
|
|
|
|
const_node_left_output = [left_node.output[i] for i in range(len(left_node.output)) if not output_deleted[i]]
|
|
|
|
|
if len(const_node_left_output) == 0:
|
|
|
|
|
self.graph.node.remove(self.node_name2module[left_node.name])
|
|
|
|
|
# self.initializer.remove(self.initializer_name2module[init_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modify_node_io_name(self, node_renamed_io):
|
|
|
|
|
for node_name in node_renamed_io.keys():
|
|
|
|
|
if node_name not in self.node_name2module.keys():
|
|
|
|
|
# custom added nodes or custom added model outputs
|
|
|
|
|
# custom added nodes or custom added model outputs, or the deleted nodes
|
|
|
|
|
continue
|
|
|
|
|
renamed_ios = node_renamed_io[node_name]
|
|
|
|
|
for src_name, dst_name in renamed_ios.items():
|
|
|
|
|
@ -164,7 +152,14 @@ class onnxModifier:
|
|
|
|
|
node.input[i] = dst_name
|
|
|
|
|
for i in range(len(node.output)):
|
|
|
|
|
if node.output[i] == src_name:
|
|
|
|
|
node.output[i] = dst_name
|
|
|
|
|
node.output[i] = dst_name
|
|
|
|
|
|
|
|
|
|
# TODO: rename the corresponding initializer and update initializer_name2module
|
|
|
|
|
if src_name in self.initializer_name2module.keys():
|
|
|
|
|
init = self.initializer_name2module[src_name]
|
|
|
|
|
init.name = dst_name
|
|
|
|
|
self.initializer_name2module[dst_name] = init
|
|
|
|
|
del self.initializer_name2module[src_name]
|
|
|
|
|
|
|
|
|
|
def modify_node_attr(self, node_changed_attr):
|
|
|
|
|
# we achieve it by deleting the original node and make a (copied) new node
|
|
|
|
|
@ -178,7 +173,7 @@ class onnxModifier:
|
|
|
|
|
# update the node_name2module and initializer_name2module
|
|
|
|
|
self.gen_name2module_map()
|
|
|
|
|
|
|
|
|
|
def add_node(self, nodes_info, node_states):
|
|
|
|
|
def add_nodes(self, nodes_info, node_states):
|
|
|
|
|
for node_info in nodes_info.values():
|
|
|
|
|
if node_states[node_info['properties']['name']] == "Deleted":
|
|
|
|
|
continue
|
|
|
|
|
@ -199,17 +194,33 @@ class onnxModifier:
|
|
|
|
|
value_info_protos.append(value_info)
|
|
|
|
|
self.graph.output.extend(value_info_protos)
|
|
|
|
|
|
|
|
|
|
def modify_initializer(self, changed_initializer):
|
|
|
|
|
for init_name, meta in changed_initializer.items():
|
|
|
|
|
# https://github.com/onnx/onnx/issues/2978
|
|
|
|
|
init_type, init_val_str = meta
|
|
|
|
|
# print(init_name, init_type, init_val)
|
|
|
|
|
init_val = parse_tensor(init_val_str, init_type)
|
|
|
|
|
# print(init_val)
|
|
|
|
|
tensor = numpy_helper.from_array(init_val, init_name)
|
|
|
|
|
self.initializer_name2module[init_name].CopyFrom(tensor)
|
|
|
|
|
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
|
'''
|
|
|
|
|
Some functions, such as modify_initializer(), should be placed
|
|
|
|
|
before modify_node_io_name(), to avoid name mismatch error.
|
|
|
|
|
'''
|
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
|
# print(modify_info['node_changed_attr'])
|
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
|
# print(modify_info['added_outputs'])
|
|
|
|
|
|
|
|
|
|
self.modify_initializer(modify_info['changed_initializer'])
|
|
|
|
|
self.change_batch_size(modify_info['rebatch_info'])
|
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
self.add_outputs(modify_info['added_outputs'])
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
|
@ -218,7 +229,7 @@ class onnxModifier:
|
|
|
|
|
os.mkdir(save_dir)
|
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
|
|
|
|
|
|
# adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
|
|
|
|
|
# adding new node like self.add_nodes() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
|
|
|
|
|
# so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked
|
|
|
|
|
# I turn off the onnx checker as a workaround.
|
|
|
|
|
# onnx.checker.check_model(self.model_proto)
|
|
|
|
|
@ -226,7 +237,10 @@ class onnxModifier:
|
|
|
|
|
print("model saved in {} !".format(save_dir))
|
|
|
|
|
|
|
|
|
|
def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
model_proto_bytes = onnx._serialize(self.model_proto)
|
|
|
|
|
inference_session = rt.InferenceSession(model_proto_bytes)
|
|
|
|
|
|
|
|
|
|
if not x:
|
|
|
|
|
x = np.random.randn(*input_shape).astype(np.float32)
|
|
|
|
|
if not output_names:
|
|
|
|
|
@ -234,20 +248,16 @@ class onnxModifier:
|
|
|
|
|
# output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[])
|
|
|
|
|
output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[])
|
|
|
|
|
self.graph.output.append(output_value_info)
|
|
|
|
|
|
|
|
|
|
model_proto_bytes = onnx._serialize(self.model_proto)
|
|
|
|
|
inference_session = rt.InferenceSession(model_proto_bytes)
|
|
|
|
|
output_names = [inference_session.get_outputs()[0].name]
|
|
|
|
|
|
|
|
|
|
input_name = inference_session.get_inputs()[0].name
|
|
|
|
|
output_name = inference_session.get_outputs()[0].name
|
|
|
|
|
|
|
|
|
|
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
|
|
|
|
|
out = inference_session.run(None, {input_name: x})[0]
|
|
|
|
|
out = inference_session.run(output_names, {input_name: x})[0]
|
|
|
|
|
print(out.shape)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_EyeNet.onnx"
|
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
|
@ -315,7 +325,7 @@ if __name__ == "__main__":
|
|
|
|
|
def test_add_node():
|
|
|
|
|
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
|
|
|
|
|
|
|
|
|
|
onnx_modifier.add_node(node_info)
|
|
|
|
|
onnx_modifier.add_nodes(node_info)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.inference()
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
@ -357,4 +367,11 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# test_change_batch_size()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_initializer():
|
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
|
onnx_modifier.modify_initializer({'onnx::Reshape_367': ['int64', '[1, 2, 32, 24, 6]']})
|
|
|
|
|
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['onnx::Transpose_368'])
|
|
|
|
|
test_modify_initializer()
|
|
|
|
|
|
|
|
|
|
|