fix issue https://github.com/ZhangGe6/onnx-modifier/issues/12 where extra inputs emerges after deleting nodes

1123
ZhangGe6 4 years ago
parent 7efab04f3a
commit f45da8d4fb

@ -88,6 +88,12 @@ class onnxModifier:
for init_name in self.initializer_name2module.keys():
if not init_name in left_node_inputs:
self.initializer.remove(self.initializer_name2module[init_name])
# remove the (model) inputs related to deleted nodes
# https://github.com/ZhangGe6/onnx-modifier/issues/12
for input_name in self.graph_input_names:
if not input_name in left_node_inputs:
self.graph.input.remove(self.node_name2module[input_name])
# remove the left unused Constant nodes
for left_node in self.graph.node:
@ -102,7 +108,6 @@ class onnxModifier:
self.graph.node.remove(self.node_name2module[left_node.name])
# self.initializer.remove(self.initializer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io):
# print(node_renamed_io)
for node_name in node_renamed_io.keys():
@ -194,7 +199,7 @@ if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_mobilenetv2-7.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
@ -242,16 +247,7 @@ if __name__ == "__main__":
print(node.input, node.output)
for initializer in onnx_modifier.initializer:
print(initializer.name)
# print(onnx_modifier.initializer_name2module.keys())
# for i, k in enumerate(onnx_modifier.initializer_name2module.keys()):
# print("\nremoving", i, k)
# onnx_modifier.graph.initializer.remove(onnx_modifier.initializer_name2module[k])
# print("removed")
# print('\nleft initializers:')
# for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name)
# print('\nleft nodes:')
# for node in onnx_modifier.graph.node:
@ -269,7 +265,7 @@ if __name__ == "__main__":
node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model()
test_modify_node_io_name()
# test_modify_node_io_name()
def test_add_node():
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
@ -289,5 +285,34 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model()
# test_change_node_attr()
def debug_remove_node_by_node_states():
# print(len(onnx_modifier.graph.node))
# print(len(onnx_modifier.graph.initializer))
# print(onnx_modifier.node_name2module.keys())
# print(onnx_modifier.graph.node)
# for node in onnx_modifier.graph.node:
# print(node.name)
# print(node.input)
# print(node.output)
# print('\noriginal input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
node_states = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist', 'Relu11': 'Exist', 'Conv12': 'Exist', 'Relu13': 'Exist', 'Conv14': 'Exist', 'Relu15': 'Exist', 'Concat16': 'Exist', 'MaxPool17': 'Exist', 'Conv18': 'Exist', 'Relu19': 'Exist', 'Conv20': 'Exist', 'Relu21': 'Exist', 'Conv22': 'Exist', 'Relu23': 'Exist', 'Concat24': 'Exist', 'Conv25': 'Exist', 'Relu26': 'Exist', 'Conv27': 'Exist', 'Relu28': 'Exist', 'Conv29': 'Exist', 'Relu30': 'Exist', 'Concat31': 'Exist', 'MaxPool32': 'Exist', 'Conv33': 'Exist', 'Relu34': 'Exist', 'Conv35': 'Exist', 'Relu36': 'Exist', 'Conv37': 'Exist', 'Relu38': 'Exist', 'Concat39': 'Exist', 'Conv40': 'Exist', 'Relu41': 'Exist', 'Conv42': 'Exist', 'Relu43': 'Exist', 'Conv44': 'Exist', 'Relu45': 'Exist', 'Concat46': 'Exist', 'Conv47': 'Exist', 'Relu48': 'Exist', 'Conv49': 'Exist', 'Relu50': 'Exist', 'Conv51': 'Exist', 'Relu52': 'Exist', 'Concat53': 'Exist', 'Conv54': 'Exist', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'out_softmaxout_1': 'Deleted'}
# print('\graph input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
onnx_modifier.remove_node_by_node_states(node_states)
print('\nleft input')
for inp in onnx_modifier.graph.input:
print(inp.name)
onnx_modifier.check_and_save_model()
debug_remove_node_by_node_states()
Loading…
Cancel
Save