From e5d60110b1b2c88e550448d58ea0bc7596dc9839 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Wed, 21 Sep 2022 11:23:40 +0800 Subject: [PATCH] fix issue when rebatch info is empty(), as is https://github.com/ZhangGe6/onnx-modifier/issues/18 mentioned --- app.py | 2 +- onnx_modifier.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 8410518..18645a1 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,7 @@ def open_model(): @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() - print(modify_info) + # print(modify_info) onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.modify(modify_info) onnx_modifier.check_and_save_model() diff --git a/onnx_modifier.py b/onnx_modifier.py index bca2616..872eca6 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -25,8 +25,10 @@ class onnxModifier: @classmethod def from_name_stream(cls, name, stream): # https://leimao.github.io/blog/ONNX-IO-Stream/ + print("loading model...") stream.seek(0) model_proto = onnx.load_model(stream, onnx.ModelProto, load_external_data=False) + print("load done!") return cls(name, model_proto) def reload(self): @@ -61,6 +63,7 @@ class onnxModifier: self.initializer_name2module[initializer.name] = initializer def change_batch_size(self, rebatch_info): + if not (rebatch_info): return # https://github.com/onnx/onnx/issues/2182#issuecomment-881752539 rebatch_type = rebatch_info['type'] rebatch_value = rebatch_info['value'] @@ -183,9 +186,10 @@ class onnxModifier: self.graph.node.append(node) - def add_outputs(self, added_outputs, node_states): + def add_outputs(self, added_outputs): # https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445 added_output_names = added_outputs.values() + if len(added_output_names) == 0: return # filter out the deleted custom-added outputs value_info_protos = [] shape_info = onnx.shape_inference.infer_shapes(self.model_proto) @@ -205,10 +209,10 @@ class onnxModifier: self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_attr(modify_info['node_changed_attr']) self.add_node(modify_info['added_node_info'], modify_info['node_states']) - self.add_outputs(modify_info['added_outputs'], modify_info['node_states']) - + self.add_outputs(modify_info['added_outputs']) def check_and_save_model(self, save_dir='./modified_onnx'): + print("saving model...") if not os.path.exists(save_dir): os.mkdir(save_dir) save_path = os.path.join(save_dir, 'modified_' + self.model_name) @@ -217,7 +221,8 @@ class onnxModifier: # so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked # I turn off the onnx checker as a workaround. # onnx.checker.check_model(self.model_proto) - onnx.save(self.model_proto, save_path) + onnx.save(self.model_proto, save_path) + print("model saved in {} !".format(save_dir)) def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None): import onnxruntime as rt