fix bug for the output has the same name with last node

1123
ZhangGe6 4 years ago
parent 046f342c2e
commit 4ff46e6acf

@ -22,7 +22,7 @@ def return_file():
def modify_and_download_model(): def modify_and_download_model():
node_states = json.loads(request.get_json()) node_states = json.loads(request.get_json())
print(node_states) # print(node_states)
onnx_modifier.remove_node_by_node_states(node_states) onnx_modifier.remove_node_by_node_states(node_states)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()

@ -14,7 +14,6 @@ class onnxModifier:
self.initializer = self.model_proto.graph.initializer self.initializer = self.model_proto.graph.initializer
self.gen_name2module_map() self.gen_name2module_map()
def gen_name2module_map(self): def gen_name2module_map(self):
# node name => node # node name => node
@ -27,8 +26,8 @@ class onnxModifier:
self.node_name2module[node.name] = node self.node_name2module[node.name] = node
for out in self.graph.output: for out in self.graph.output:
self.node_name2module[out.name] = out self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node
self.graph_output_names = [out.name for out in self.graph.output] self.graph_output_names = ["out_" + out.name for out in self.graph.output]
# print(self.node_name2module.keys()) # print(self.node_name2module.keys())
# initializer name => initializer # initializer name => initializer
@ -79,7 +78,7 @@ class onnxModifier:
def check_and_save_model(self, save_dir='./res_onnx'): def check_and_save_model(self, save_dir='./res_onnx'):
save_path = os.path.join(save_dir, 'modified_' + self.model_name) save_path = os.path.join(save_dir, 'modified_' + self.model_name)
onnx.checker.check_model(self.model_proto) # onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path) onnx.save(self.model_proto, save_path)
def inference(self): def inference(self):
@ -89,8 +88,9 @@ class onnxModifier:
if __name__ == "__main__": if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def remove_node_by_node_states(): def remove_node_by_node_states():
@ -101,9 +101,10 @@ if __name__ == "__main__":
node_states_quant = {'data_0': 'Exist', 'data_0_QuantizeLinear': 'Exist', 'Conv_nc_rename_0_quant': 'Exist', 'MaxPool_nc_rename_2_quant': 'Exist', 'Conv_nc_rename_3_quant': 'Deleted', 'Conv_nc_rename_5_quant': 'Deleted', 'Conv_nc_rename_7_quant': 'Deleted', 'fire2/expand1x1_2_DequantizeLinear': 'Deleted', 'fire2/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_9': 'Deleted', 'fire2/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_10_quant': 'Deleted', 'Conv_nc_rename_12_quant': 'Deleted', 'Conv_nc_rename_14_quant': 'Deleted', 'fire3/expand1x1_2_DequantizeLinear': 'Deleted', node_states_quant = {'data_0': 'Exist', 'data_0_QuantizeLinear': 'Exist', 'Conv_nc_rename_0_quant': 'Exist', 'MaxPool_nc_rename_2_quant': 'Exist', 'Conv_nc_rename_3_quant': 'Deleted', 'Conv_nc_rename_5_quant': 'Deleted', 'Conv_nc_rename_7_quant': 'Deleted', 'fire2/expand1x1_2_DequantizeLinear': 'Deleted', 'fire2/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_9': 'Deleted', 'fire2/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_10_quant': 'Deleted', 'Conv_nc_rename_12_quant': 'Deleted', 'Conv_nc_rename_14_quant': 'Deleted', 'fire3/expand1x1_2_DequantizeLinear': 'Deleted',
'fire3/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_16': 'Deleted', 'MaxPool_nc_rename_17': 'Deleted', 'pool3_1_QuantizeLinear': 'fire3/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_16': 'Deleted', 'MaxPool_nc_rename_17': 'Deleted', 'pool3_1_QuantizeLinear':
'Deleted', 'Conv_nc_rename_18_quant': 'Deleted', 'Conv_nc_rename_20_quant': 'Deleted', 'Conv_nc_rename_22_quant': 'Deleted', 'fire4/expand1x1_2_DequantizeLinear': 'Deleted', 'fire4/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_24': 'Deleted', 'fire4/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_25_quant': 'Deleted', 'Conv_nc_rename_27_quant': 'Deleted', 'Conv_nc_rename_29_quant': 'Deleted', 'fire5/expand1x1_2_DequantizeLinear': 'Deleted', 'fire5/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_31': 'Deleted', 'MaxPool_nc_rename_32': 'Deleted', 'pool5_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_33_quant': 'Deleted', 'Conv_nc_rename_35_quant': 'Deleted', 'Conv_nc_rename_37_quant': 'Deleted', 'fire6/expand1x1_2_DequantizeLinear': 'Deleted', 'fire6/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_39': 'Deleted', 'fire6/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_40_quant': 'Deleted', 'Conv_nc_rename_42_quant': 'Deleted', 'Conv_nc_rename_44_quant': 'Deleted', 'fire7/expand1x1_2_DequantizeLinear': 'Deleted', 'fire7/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_46': 'Deleted', 'fire7/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_47_quant': 'Deleted', 'Conv_nc_rename_49_quant': 'Deleted', 'Conv_nc_rename_51_quant': 'Deleted', 'fire8/expand1x1_2_DequantizeLinear': 'Deleted', 'fire8/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_53': 'Deleted', 'fire8/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_54_quant': 'Deleted', 'Conv_nc_rename_56_quant': 'Deleted', 'Conv_nc_rename_58_quant': 'Deleted', 'fire9/expand1x1_2_DequantizeLinear': 'Deleted', 'fire9/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_60': 'Deleted', 'fire9/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_61_quant': 'Deleted', 'GlobalAveragePool_nc_rename_63_quant': 'Deleted', 'pool10_1_DequantizeLinear': 'Deleted', 'Softmax_nc_rename_64': 'Deleted', 'softmaxout_1': 'Deleted'} 'Deleted', 'Conv_nc_rename_18_quant': 'Deleted', 'Conv_nc_rename_20_quant': 'Deleted', 'Conv_nc_rename_22_quant': 'Deleted', 'fire4/expand1x1_2_DequantizeLinear': 'Deleted', 'fire4/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_24': 'Deleted', 'fire4/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_25_quant': 'Deleted', 'Conv_nc_rename_27_quant': 'Deleted', 'Conv_nc_rename_29_quant': 'Deleted', 'fire5/expand1x1_2_DequantizeLinear': 'Deleted', 'fire5/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_31': 'Deleted', 'MaxPool_nc_rename_32': 'Deleted', 'pool5_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_33_quant': 'Deleted', 'Conv_nc_rename_35_quant': 'Deleted', 'Conv_nc_rename_37_quant': 'Deleted', 'fire6/expand1x1_2_DequantizeLinear': 'Deleted', 'fire6/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_39': 'Deleted', 'fire6/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_40_quant': 'Deleted', 'Conv_nc_rename_42_quant': 'Deleted', 'Conv_nc_rename_44_quant': 'Deleted', 'fire7/expand1x1_2_DequantizeLinear': 'Deleted', 'fire7/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_46': 'Deleted', 'fire7/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_47_quant': 'Deleted', 'Conv_nc_rename_49_quant': 'Deleted', 'Conv_nc_rename_51_quant': 'Deleted', 'fire8/expand1x1_2_DequantizeLinear': 'Deleted', 'fire8/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_53': 'Deleted', 'fire8/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_54_quant': 'Deleted', 'Conv_nc_rename_56_quant': 'Deleted', 'Conv_nc_rename_58_quant': 'Deleted', 'fire9/expand1x1_2_DequantizeLinear': 'Deleted', 'fire9/expand3x3_2_DequantizeLinear': 'Deleted', 'Concat_nc_rename_60': 'Deleted', 'fire9/concat_1_QuantizeLinear': 'Deleted', 'Conv_nc_rename_61_quant': 'Deleted', 'GlobalAveragePool_nc_rename_63_quant': 'Deleted', 'pool10_1_DequantizeLinear': 'Deleted', 'Softmax_nc_rename_64': 'Deleted', 'softmaxout_1': 'Deleted'}
# node_states = node_states_quant node_states = node_states_quant
node_states = node_states_fp # node_states = node_states_fp
# print('\graph input') # print('\graph input')
# for inp in onnx_modifier.graph.input: # for inp in onnx_modifier.graph.input:
# print(inp.name) # print(inp.name)

@ -899,9 +899,10 @@ view.Graph = class extends grapher.Graph {
} }
createOutput(output) { createOutput(output) {
const value = new view.Output(this, output); var modelNodeName = "out_" + output.name; // in case the output has the same name with the last node
const value = new view.Output(this, output, modelNodeName);
// value.name = (this._nodeKey++).toString(); // value.name = (this._nodeKey++).toString();
value.name = output.name; // output nodes should have name value.name = "out_" + output.name; // output nodes should have name
this.setNode(value); this.setNode(value);
return value; return value;
} }
@ -1281,13 +1282,13 @@ view.Input = class extends grapher.Node {
view.Output = class extends grapher.Node { view.Output = class extends grapher.Node {
constructor(context, value) { constructor(context, value, modelNodeName) {
super(); super();
this.context = context; this.context = context;
this.value = value; this.value = value;
const types = value.arguments.map((argument) => argument.type || '').join('\n'); const types = value.arguments.map((argument) => argument.type || '').join('\n');
let name = value.name || ''; let name = value.name || '';
this.modelNodeName = value.name this.modelNodeName = modelNodeName
if (name.length > 16) { if (name.length > 16) {
name = name.split('/').pop(); name = name.split('/').pop();
} }

Loading…
Cancel
Save