@ -119,21 +119,21 @@ class onnxModifier:
# print('removing node {} ...'.format(node_name))
self . remove_node_by_name ( node_name )
# remove node initializers (parameters), aka, keep and only keep the initializers of left nodes
left_node_inputs = [ ]
for left_node in self . graph . node :
left_node_inputs + = left_node . input
remained_node_inputs = [ ]
for remained_node in self . graph . node :
remained_node_inputs + = remained_node . input
# remove node initializers (parameters), aka, keep and only keep the initializers of remained nodes
for init_name in self . initializer_name2module . keys ( ) :
if not init_name in left _node_inputs:
if not init_name in remained _node_inputs:
self . initializer . remove ( self . initializer_name2module [ init_name ] )
# remove the (model) inputs related to deleted nodes
# https://github.com/ZhangGe6/onnx-modifier/issues/12
for input_name in self . graph_input_names :
if not input_name in left _node_inputs:
if not input_name in remained _node_inputs:
self . graph . input . remove ( self . node_name2module [ input_name ] )
def modify_node_io_name ( self , node_renamed_io ) :
for node_name in node_renamed_io . keys ( ) :
if node_name not in self . node_name2module . keys ( ) :
@ -192,7 +192,9 @@ class onnxModifier:
# filter out the deleted custom-added outputs
value_info_protos = [ ]
shape_info = onnx . shape_inference . infer_shapes ( self . model_proto )
print ( added_output_names )
for value_info in shape_info . graph . value_info :
print ( value_info . name )
if value_info . name in added_output_names :
value_info_protos . append ( value_info )
self . graph . output . extend ( value_info_protos )
@ -221,6 +223,33 @@ class onnxModifier:
self . initializer . append ( initializer_tensor )
self . initializer_name2module [ init_name ] = initializer_tensor
def post_process ( self ) :
def remove_isolated_nodes ( ) :
# remove the remained corresponding isolated nodes, like Constant
remained_node_inputs , remained_node_outputs = [ ] , [ ]
for remained_node in self . graph . node :
remained_node_inputs + = remained_node . input
remained_node_outputs + = remained_node . output
for remained_node in self . graph . node :
# delete the node if it does not serve as the input or output of any other nodes
unused = True
for output in remained_node . output :
if output in remained_node_inputs :
unused = False
break
for input in remained_node . input :
if input in remained_node_outputs :
unused = False
break
if unused :
self . graph . node . remove ( self . node_name2module [ remained_node . name ] )
for inp in remained_node . input :
if inp in self . initializer_name2module . keys ( ) :
self . initializer . remove ( self . initializer_name2module [ inp ] )
remove_isolated_nodes ( )
def modify ( self , modify_info ) :
'''
1. Some functions , such as modify_initializer ( ) , should be placed
@ -242,6 +271,8 @@ class onnxModifier:
self . modify_node_io_name ( modify_info [ ' node_renamed_io ' ] )
self . modify_node_attr ( modify_info [ ' node_changed_attr ' ] )
self . add_outputs ( modify_info [ ' added_outputs ' ] )
self . post_process ( )
def check_and_save_model ( self , save_dir = ' ./modified_onnx ' ) :
print ( " saving model... " )
@ -280,8 +311,9 @@ if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
model_path = " C: \\ Users \\ ZhangGe \\ Desktop \\ modified_test_edit_init.onnx "
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_test_edit_init.onnx "
# model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
model_path = " C: \\ Users \\ ZhangGe \\ Desktop \\ tiny_squeezenet1.0-3.onnx "
onnx_modifier = onnxModifier . from_model_path ( model_path )
def explore_basic ( ) :
@ -366,12 +398,12 @@ if __name__ == "__main__":
def test_inference ( ) :
onnx_modifier . inference ( input_shape = [ 1 , 1 , 192 , 192 ] , output_names = [ " onnx::Transpose_368 " ] )
test_inference ( )
# test_inference( )
def test_add_output ( ) :
# print(onnx_modifier.graph.output)
onnx_modifier . add_outputs ( [' fire2/squeeze1x1_1 ' ] )
# print(onnx_modifier.graph.output )
onnx_modifier . add_outputs ( {' 0 ' : ' out ' } )
print ( onnx_modifier . graph . output )
onnx_modifier . check_and_save_model ( )
# test_add_output()
@ -406,4 +438,10 @@ if __name__ == "__main__":
# print(onnx_modifier.initializer_name2module.keys())
# for initializer in onnx_modifier.initializer:
# print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}")
# test_modify_new_initializer()
# test_modify_new_initializer()
def test_remove_isolated_nodes ( ) :
modify_info = { ' node_states ' : { ' data_0 ' : ' Exist ' , ' Conv0 ' : ' Exist ' , ' Relu1 ' : ' Exist ' , ' MaxPool2 ' : ' Exist ' , ' Conv3 ' : ' Exist ' , ' Relu4 ' : ' Exist ' , ' Conv5 ' : ' Exist ' , ' Relu6 ' : ' Exist ' , ' Conv7 ' : ' Exist ' , ' Relu8 ' : ' Exist ' , ' Concat9 ' : ' Exist ' , ' Conv10 ' : ' Exist ' } , ' node_renamed_io ' : { ' Conv3 ' : { ' pool1_1 ' : ' conv1_2 ' } , ' MaxPool2 ' : { ' conv1_2 ' : ' conv1 ' } } , ' node_changed_attr ' : { } , ' added_node_info ' : { } , ' added_outputs ' : { } , ' rebatch_info ' : { } , ' changed_initializer ' : { } }
onnx_modifier . modify ( modify_info )
onnx_modifier . check_and_save_model ( )
test_remove_isolated_nodes ( )