|
|
|
|
@ -109,11 +109,12 @@ class onnxModifier:
|
|
|
|
|
# self.initializer.remove(self.initializer_name2module[init_name])
|
|
|
|
|
|
|
|
|
|
def modify_node_io_name(self, node_renamed_io):
|
|
|
|
|
# print(node_renamed_io)
|
|
|
|
|
for node_name in node_renamed_io.keys():
|
|
|
|
|
if node_name not in self.node_name2module.keys():
|
|
|
|
|
# custom added nodes or custom added model outputs
|
|
|
|
|
continue
|
|
|
|
|
renamed_ios = node_renamed_io[node_name]
|
|
|
|
|
for src_name, dst_name in renamed_ios.items():
|
|
|
|
|
# print(src_name, dst_name)
|
|
|
|
|
node = self.node_name2module[node_name]
|
|
|
|
|
if node_name in self.graph_input_names:
|
|
|
|
|
node.name = dst_name
|
|
|
|
|
@ -149,16 +150,28 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
self.graph.node.append(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_outputs(self, added_outputs, node_states):
|
|
|
|
|
# https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445
|
|
|
|
|
added_output_names = added_outputs.values()
|
|
|
|
|
# filter out the deleted custom-added outputs
|
|
|
|
|
value_info_protos = []
|
|
|
|
|
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
|
|
|
|
|
for value_info in shape_info.graph.value_info:
|
|
|
|
|
if value_info.name in added_output_names:
|
|
|
|
|
value_info_protos.append(value_info)
|
|
|
|
|
self.graph.output.extend(value_info_protos)
|
|
|
|
|
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
|
# print(modify_info['node_changed_attr'])
|
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
|
# print(modify_info['added_outputs'])
|
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
self.add_outputs(modify_info['added_outputs'], modify_info['node_states'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
|
@ -191,15 +204,15 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
|
|
|
|
|
out = inference_session.run(None, {input_name: x})[0]
|
|
|
|
|
# print(out)
|
|
|
|
|
|
|
|
|
|
print(out.shape)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
|
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
|
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
@ -259,7 +272,6 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# remove_node_by_node_states()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_node_io_name():
|
|
|
|
|
node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
|
|
|
|
|
@ -285,34 +297,14 @@ if __name__ == "__main__":
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# test_change_node_attr()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def debug_remove_node_by_node_states():
|
|
|
|
|
# print(len(onnx_modifier.graph.node))
|
|
|
|
|
# print(len(onnx_modifier.graph.initializer))
|
|
|
|
|
|
|
|
|
|
# print(onnx_modifier.node_name2module.keys())
|
|
|
|
|
# print(onnx_modifier.graph.node)
|
|
|
|
|
# for node in onnx_modifier.graph.node:
|
|
|
|
|
# print(node.name)
|
|
|
|
|
# print(node.input)
|
|
|
|
|
# print(node.output)
|
|
|
|
|
|
|
|
|
|
# print('\noriginal input')
|
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
|
# print(inp.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
node_states = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist', 'Relu11': 'Exist', 'Conv12': 'Exist', 'Relu13': 'Exist', 'Conv14': 'Exist', 'Relu15': 'Exist', 'Concat16': 'Exist', 'MaxPool17': 'Exist', 'Conv18': 'Exist', 'Relu19': 'Exist', 'Conv20': 'Exist', 'Relu21': 'Exist', 'Conv22': 'Exist', 'Relu23': 'Exist', 'Concat24': 'Exist', 'Conv25': 'Exist', 'Relu26': 'Exist', 'Conv27': 'Exist', 'Relu28': 'Exist', 'Conv29': 'Exist', 'Relu30': 'Exist', 'Concat31': 'Exist', 'MaxPool32': 'Exist', 'Conv33': 'Exist', 'Relu34': 'Exist', 'Conv35': 'Exist', 'Relu36': 'Exist', 'Conv37': 'Exist', 'Relu38': 'Exist', 'Concat39': 'Exist', 'Conv40': 'Exist', 'Relu41': 'Exist', 'Conv42': 'Exist', 'Relu43': 'Exist', 'Conv44': 'Exist', 'Relu45': 'Exist', 'Concat46': 'Exist', 'Conv47': 'Exist', 'Relu48': 'Exist', 'Conv49': 'Exist', 'Relu50': 'Exist', 'Conv51': 'Exist', 'Relu52': 'Exist', 'Concat53': 'Exist', 'Conv54': 'Exist', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'out_softmaxout_1': 'Deleted'}
|
|
|
|
|
# print('\graph input')
|
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
|
# print(inp.name)
|
|
|
|
|
onnx_modifier.remove_node_by_node_states(node_states)
|
|
|
|
|
|
|
|
|
|
print('\nleft input')
|
|
|
|
|
for inp in onnx_modifier.graph.input:
|
|
|
|
|
print(inp.name)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
debug_remove_node_by_node_states()
|
|
|
|
|
def test_inference():
|
|
|
|
|
onnx_modifier.inference()
|
|
|
|
|
test_inference()
|
|
|
|
|
|
|
|
|
|
def test_add_output():
|
|
|
|
|
# print(onnx_modifier.graph.output)
|
|
|
|
|
onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
|
|
|
|
|
# print(onnx_modifier.graph.output)
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
# test_add_output()
|
|
|
|
|
|