|
|
|
@ -15,7 +15,6 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
|
|
self.gen_name2module_map()
|
|
|
|
self.gen_name2module_map()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_name2module_map(self):
|
|
|
|
def gen_name2module_map(self):
|
|
|
|
# node name => node
|
|
|
|
# node name => node
|
|
|
|
self.node_name2module = dict()
|
|
|
|
self.node_name2module = dict()
|
|
|
|
@ -27,8 +26,8 @@ class onnxModifier:
|
|
|
|
self.node_name2module[node.name] = node
|
|
|
|
self.node_name2module[node.name] = node
|
|
|
|
|
|
|
|
|
|
|
|
for out in self.graph.output:
|
|
|
|
for out in self.graph.output:
|
|
|
|
self.node_name2module[out.name] = out
|
|
|
|
self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node
|
|
|
|
self.graph_output_names = [out.name for out in self.graph.output]
|
|
|
|
self.graph_output_names = ["out_" + out.name for out in self.graph.output]
|
|
|
|
# print(self.node_name2module.keys())
|
|
|
|
# print(self.node_name2module.keys())
|
|
|
|
|
|
|
|
|
|
|
|
# initializer name => initializer
|
|
|
|
# initializer name => initializer
|
|
|
|
@ -79,7 +78,7 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./res_onnx'):
|
|
|
|
def check_and_save_model(self, save_dir='./res_onnx'):
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
onnx.checker.check_model(self.model_proto)
|
|
|
|
# onnx.checker.check_model(self.model_proto)
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
def inference(self):
|
|
|
|
def inference(self):
|
|
|
|
@ -89,8 +88,9 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
|
|
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
|
|
|
|
|
|
|
|
model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
onnx_modifier = onnxModifier.from_model_path(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
def remove_node_by_node_states():
|
|
|
|
def remove_node_by_node_states():
|
|
|
|
@ -102,8 +102,9 @@ if __name__ == "__main__":
|
|
|
|
'fire3/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_16': 'Deleted', 'MaxPool_nc_rename_17': 'Deleted', 'pool3_1_QuantizeLinear':
|
|
|
|
'fire3/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_16': 'Deleted', 'MaxPool_nc_rename_17': 'Deleted', 'pool3_1_QuantizeLinear':
|
|
|
|
'Deleted', 'Conv_nc_rename_18_quant': 'Deleted', 'Conv_nc_rename_20_quant': 'Deleted', 'Conv_nc_rename_22_quant': 'Deleted', 'fire4/expand1x1_2_DequantizeLinear': 'Deleted', 'fire4/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_24': 'Deleted', 'fire4/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_25_quant': 'Deleted', 'Conv_nc_rename_27_quant': 'Deleted', 'Conv_nc_rename_29_quant': 'Deleted', 'fire5/expand1x1_2_DequantizeLinear': 'Deleted', 'fire5/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_31': 'Deleted', 'MaxPool_nc_rename_32': 'Deleted', 'pool5_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_33_quant': 'Deleted', 'Conv_nc_rename_35_quant': 'Deleted', 'Conv_nc_rename_37_quant': 'Deleted', 'fire6/expand1x1_2_DequantizeLinear': 'Deleted', 'fire6/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_39': 'Deleted', 'fire6/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_40_quant': 'Deleted', 'Conv_nc_rename_42_quant': 'Deleted', 'Conv_nc_rename_44_quant': 'Deleted', 'fire7/expand1x1_2_DequantizeLinear': 'Deleted', 'fire7/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_46': 'Deleted', 'fire7/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_47_quant': 'Deleted', 'Conv_nc_rename_49_quant': 'Deleted', 'Conv_nc_rename_51_quant': 'Deleted', 'fire8/expand1x1_2_DequantizeLinear': 'Deleted', 'fire8/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_53': 'Deleted', 'fire8/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_54_quant': 'Deleted', 'Conv_nc_rename_56_quant': 'Deleted', 'Conv_nc_rename_58_quant': 'Deleted', 'fire9/expand1x1_2_DequantizeLinear': 'Deleted', 'fire9/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_60': 'Deleted', 'fire9/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_61_quant': 'Deleted', 'GlobalAveragePool_nc_rename_63_quant': 'Deleted', 'pool10_1_DequantizeLinear': 'Deleted', 'Softmax_nc_rename_64': 'Deleted', 'softmaxout_1': 'Deleted'}
|
|
|
|
'Deleted', 'Conv_nc_rename_18_quant': 'Deleted', 'Conv_nc_rename_20_quant': 'Deleted', 'Conv_nc_rename_22_quant': 'Deleted', 'fire4/expand1x1_2_DequantizeLinear': 'Deleted', 'fire4/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_24': 'Deleted', 'fire4/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_25_quant': 'Deleted', 'Conv_nc_rename_27_quant': 'Deleted', 'Conv_nc_rename_29_quant': 'Deleted', 'fire5/expand1x1_2_DequantizeLinear': 'Deleted', 'fire5/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_31': 'Deleted', 'MaxPool_nc_rename_32': 'Deleted', 'pool5_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_33_quant': 'Deleted', 'Conv_nc_rename_35_quant': 'Deleted', 'Conv_nc_rename_37_quant': 'Deleted', 'fire6/expand1x1_2_DequantizeLinear': 'Deleted', 'fire6/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_39': 'Deleted', 'fire6/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_40_quant': 'Deleted', 'Conv_nc_rename_42_quant': 'Deleted', 'Conv_nc_rename_44_quant': 'Deleted', 'fire7/expand1x1_2_DequantizeLinear': 'Deleted', 'fire7/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_46': 'Deleted', 'fire7/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_47_quant': 'Deleted', 'Conv_nc_rename_49_quant': 'Deleted', 'Conv_nc_rename_51_quant': 'Deleted', 'fire8/expand1x1_2_DequantizeLinear': 'Deleted', 'fire8/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_53': 'Deleted', 'fire8/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_54_quant': 'Deleted', 'Conv_nc_rename_56_quant': 'Deleted', 'Conv_nc_rename_58_quant': 'Deleted', 'fire9/expand1x1_2_DequantizeLinear': 'Deleted', 'fire9/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_60': 'Deleted', 'fire9/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_61_quant': 'Deleted', 'GlobalAveragePool_nc_rename_63_quant': 'Deleted', 'pool10_1_DequantizeLinear': 'Deleted', 'Softmax_nc_rename_64': 'Deleted', 'softmaxout_1': 'Deleted'}
|
|
|
|
|
|
|
|
|
|
|
|
# node_states = node_states_quant
|
|
|
|
|
|
|
|
node_states = node_states_fp
|
|
|
|
node_states = node_states_quant
|
|
|
|
|
|
|
|
# node_states = node_states_fp
|
|
|
|
# print('\graph input')
|
|
|
|
# print('\graph input')
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
# for inp in onnx_modifier.graph.input:
|
|
|
|
# print(inp.name)
|
|
|
|
# print(inp.name)
|
|
|
|
|