# References:
# https://leimao.github.io/blog/ONNX-Python-API/
# https://github.com/saurabh-shandilya/onnx-utils

import io
import os
from platform import node  # NOTE(review): unused and shadowed by loop variables below — candidate for removal

import onnx

|
class onnxModifier:
    """Wrapper around an ONNX ``ModelProto`` that supports removing graph
    nodes and graph outputs by name.

    The constructor builds a name -> proto map over the graph's inputs,
    nodes, and outputs so callers can address any element by its name.
    """

    def __init__(self, model_name, model_proto):
        """Store the model and index its graph elements by name.

        Args:
            model_name: File name used when saving the modified model.
            model_proto: A loaded ``onnx.ModelProto``.
        """
        self.model_name = model_name
        self.model_proto = model_proto
        self.graph = self.model_proto.graph

        self.gen_node_name2module_map()

    def gen_node_name2module_map(self):
        """Build ``self.node_name2module``: name -> input/node/output proto.

        ONNX allows nodes with empty names; those are renamed in place to
        ``<op_type><idx>`` so that every node is addressable by name.
        Also records ``self.graph_output_names`` for later lookups.
        """
        self.node_name2module = {}
        node_idx = 0
        for graph_input in self.graph.input:
            node_idx += 1
            self.node_name2module[graph_input.name] = graph_input

        for node in self.graph.node:
            if node.name == '':
                # Synthesize a name for anonymous nodes; the running index
                # keeps the generated names distinct.
                node.name = str(node.op_type) + str(node_idx)
                node_idx += 1
            self.node_name2module[node.name] = node

        for graph_output in self.graph.output:
            node_idx += 1
            self.node_name2module[graph_output.name] = graph_output
        self.graph_output_names = [out.name for out in self.graph.output]
        # print(self.node_name2module.keys())

    @classmethod
    def from_model_path(cls, model_path):
        """Alternate constructor: load the model from a file path."""
        model_name = os.path.basename(model_path)
        model_proto = onnx.load(model_path)
        return cls(model_name, model_proto)

    @classmethod
    def from_name_stream(cls, name, stream):
        """Alternate constructor: load the model from a binary stream."""
        # https://leimao.github.io/blog/ONNX-IO-Stream/
        stream.seek(0)
        model_proto = onnx.load_model(stream, onnx.ModelProto)
        return cls(name, model_proto)

    def remove_node_by_name(self, node_name):
        """Remove a single graph node, addressed by its (possibly generated) name."""
        self.graph.node.remove(self.node_name2module[node_name])

    def remove_output_by_name(self, node_name):
        """Remove a graph output, addressed by its name."""
        self.graph.output.remove(self.node_name2module[node_name])

    def remove_node_by_node_states(self, node_states):
        """Apply a {name: state} mapping; entries marked 'Deleted' are removed.

        Names that are graph outputs are removed from ``graph.output``;
        all other names are treated as graph nodes.
        """
        for node_name, node_state in node_states.items():
            if node_state == 'Deleted':
                if node_name in self.graph_output_names:
                    self.remove_output_by_name(node_name)
                else:
                    self.remove_node_by_name(node_name)

    def check_and_save_model(self, save_dir='./res_onnx'):
        """Validate the (modified) model and save it under ``save_dir``.

        The output file is named ``modified_<original model name>``.
        """
        # Fix: onnx.save fails with FileNotFoundError if the target
        # directory does not exist yet — create it first.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, 'modified_' + self.model_name)
        onnx.checker.check_model(self.model_proto)
        onnx.save(self.model_proto, save_path)

    def inference(self):
        """Placeholder for running the model with onnxruntime (not implemented)."""
        # model_proto_bytes = onnx._serialize(model_proto_from_stream)
        # inference_session = rt.InferenceSession(model_proto_bytes)
        pass
if __name__ == "__main__":
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
onnx_modifier.remove_node_by_name('Softmax_nc_rename_64')
|
|
onnx_modifier.remove_output_by_name('softmaxout_1')
|
|
# onnx_modifier.graph.output.remove(onnx_modifier.node_name2module['softmaxout_1'])
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
# print(onnx_modifier.graph.input)
|
|
# print(onnx_modifier.graph.output)
|
|
# print(onnx_modifier.node_name2module['Softmax_nc_rename_64'])
|
|
# print(onnx_modifier.node_name2module['softmaxout_1'])
|
|
# onnx_modifier.remove_node_by_name('softmaxout_1')
|
|
|
|
|
|
# for node in onnx_modifier.graph.output:
|
|
# print(node.name)
|
|
|
|
|
|
|
|
|
|
|