|
|
|
|
@ -5,7 +5,10 @@
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import copy
|
|
|
|
|
import numpy as np
|
|
|
|
|
import onnx
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
from utils import make_node
|
|
|
|
|
|
|
|
|
|
class onnxModifier:
|
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
|
@ -63,6 +66,9 @@ class onnxModifier:
|
|
|
|
|
def remove_node_by_node_states(self, node_states):
|
|
|
|
|
# remove node from graph
|
|
|
|
|
for node_name, node_state in node_states.items():
|
|
|
|
|
if not (node_name in self.node_name2module):
|
|
|
|
|
# for custom added node here
|
|
|
|
|
continue
|
|
|
|
|
if node_state == 'Deleted':
|
|
|
|
|
if node_name in self.graph_output_names:
|
|
|
|
|
# print('removing output {} ...'.format(node_name))
|
|
|
|
|
@ -98,31 +104,15 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
def add_node(self, nodes_info):
|
|
|
|
|
for node_info in nodes_info.values():
|
|
|
|
|
name = node_info['properties']['name']
|
|
|
|
|
op_type = node_info['properties']['op_type']
|
|
|
|
|
attributes = node_info['attributes']
|
|
|
|
|
|
|
|
|
|
inputs = []
|
|
|
|
|
for key in node_info['inputs'].keys():
|
|
|
|
|
inputs += node_info['inputs'][key]
|
|
|
|
|
outputs = []
|
|
|
|
|
for key in node_info['outputs'].keys():
|
|
|
|
|
outputs += node_info['outputs'][key]
|
|
|
|
|
|
|
|
|
|
node = onnx.helper.make_node(
|
|
|
|
|
op_type=op_type,
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
outputs=outputs,
|
|
|
|
|
name=name,
|
|
|
|
|
**attributes
|
|
|
|
|
)
|
|
|
|
|
node = make_node(node_info)
|
|
|
|
|
# print(node)
|
|
|
|
|
|
|
|
|
|
self.graph.node.append(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
|
print(modify_info['node_renamed_io'])
|
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
|
print(modify_info['added_node_info'])
|
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'])
|
|
|
|
|
@ -135,11 +125,29 @@ class onnxModifier:
|
|
|
|
|
onnx.checker.check_model(self.model_proto)
|
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
|
|
|
|
|
|
def inference(self):
|
|
|
|
|
# model_proto_bytes = onnx._serialize(model_proto_from_stream)
|
|
|
|
|
# inference_session = rt.InferenceSession(model_proto_bytes)
|
|
|
|
|
pass
|
|
|
|
|
def inference(self, x=None, output_names=None):
|
|
|
|
|
if not x:
|
|
|
|
|
input_shape = [1, 3, 224, 224]
|
|
|
|
|
x = np.random.randn(*input_shape).astype(np.float32)
|
|
|
|
|
if not output_names:
|
|
|
|
|
output_name = self.graph.node[-1].output[0]
|
|
|
|
|
# output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[])
|
|
|
|
|
output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[])
|
|
|
|
|
self.graph.output.append(output_value_info)
|
|
|
|
|
|
|
|
|
|
model_proto_bytes = onnx._serialize(self.model_proto)
|
|
|
|
|
inference_session = rt.InferenceSession(model_proto_bytes)
|
|
|
|
|
|
|
|
|
|
input_name = inference_session.get_inputs()[0].name
|
|
|
|
|
output_name = inference_session.get_outputs()[0].name
|
|
|
|
|
# print(input_name)
|
|
|
|
|
# print(output_name)
|
|
|
|
|
|
|
|
|
|
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
|
|
|
|
|
out = inference_session.run(None, {input_name: x})[0]
|
|
|
|
|
|
|
|
|
|
# print(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
@ -147,7 +155,25 @@ if __name__ == "__main__":
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
|
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
|
print(type(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
print(dir(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.node))
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
|
|
|
|
|
for node in onnx_modifier.model_proto.graph.node:
|
|
|
|
|
print(node.name)
|
|
|
|
|
print(node.input)
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
# for initializer in onnx_modifier.model_proto.graph.initializer:
|
|
|
|
|
# print(initializer.name)
|
|
|
|
|
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
|
|
|
|
|
pass
|
|
|
|
|
# explore_basic()
|
|
|
|
|
|
|
|
|
|
def remove_node_by_node_states():
|
|
|
|
|
print(len(onnx_modifier.graph.node))
|
|
|
|
|
print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
@ -184,24 +210,7 @@ if __name__ == "__main__":
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# remove_node_by_node_states()
|
|
|
|
|
|
|
|
|
|
def explore_basic():
|
|
|
|
|
print(type(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
print(dir(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.node))
|
|
|
|
|
print(len(onnx_modifier.model_proto.graph.initializer))
|
|
|
|
|
|
|
|
|
|
for node in onnx_modifier.model_proto.graph.node:
|
|
|
|
|
print(node.name)
|
|
|
|
|
print(node.input)
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
# for initializer in onnx_modifier.model_proto.graph.initializer:
|
|
|
|
|
# print(initializer.name)
|
|
|
|
|
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
|
|
|
|
|
pass
|
|
|
|
|
# explore_basic()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_node_io_name():
|
|
|
|
|
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
|
|
|
|
|
onnx_modifier.modify_node_io_name(node_rename_io)
|
|
|
|
|
@ -209,9 +218,11 @@ if __name__ == "__main__":
|
|
|
|
|
# test_modify_node_io_name()
|
|
|
|
|
|
|
|
|
|
def test_add_node():
|
|
|
|
|
node_info = {'properties': {'domain': 'ai.onnx', 'op_type': 'Abs', 'name': 'custom_added_Abs0'}, 'attributes': {}, 'inputs': {'X': ['custom_input_0']}, 'outputs': {'Y': ['custom_output_1']}}
|
|
|
|
|
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
|
|
|
|
|
|
|
|
|
|
onnx_modifier.add_node(node_info)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.inference()
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
|
|
|
|
|
test_add_node()
|
|
|
|
|
|