fix issue when rebatch info is empty(), as is https://github.com/ZhangGe6/onnx-modifier/issues/18 mentioned

1123
ZhangGe6 3 years ago
parent 7e8e7e9ffa
commit e5d60110b1

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
print(modify_info)
# print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()

@ -25,8 +25,10 @@ class onnxModifier:
@classmethod
def from_name_stream(cls, name, stream):
# https://leimao.github.io/blog/ONNX-IO-Stream/
print("loading model...")
stream.seek(0)
model_proto = onnx.load_model(stream, onnx.ModelProto, load_external_data=False)
print("load done!")
return cls(name, model_proto)
def reload(self):
@ -61,6 +63,7 @@ class onnxModifier:
self.initializer_name2module[initializer.name] = initializer
def change_batch_size(self, rebatch_info):
if not (rebatch_info): return
# https://github.com/onnx/onnx/issues/2182#issuecomment-881752539
rebatch_type = rebatch_info['type']
rebatch_value = rebatch_info['value']
@ -183,9 +186,10 @@ class onnxModifier:
self.graph.node.append(node)
def add_outputs(self, added_outputs, node_states):
def add_outputs(self, added_outputs):
# https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445
added_output_names = added_outputs.values()
if len(added_output_names) == 0: return
# filter out the deleted custom-added outputs
value_info_protos = []
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
@ -205,10 +209,10 @@ class onnxModifier:
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs'])
def check_and_save_model(self, save_dir='./modified_onnx'):
print("saving model...")
if not os.path.exists(save_dir):
os.mkdir(save_dir)
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
@ -218,6 +222,7 @@ class onnxModifier:
# I turn off the onnx checker as a workaround.
# onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
print("model saved in {} !".format(save_dir))
def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
import onnxruntime as rt

Loading…
Cancel
Save