|
|
|
|
@ -25,8 +25,10 @@ class onnxModifier:
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_name_stream(cls, name, stream):
|
|
|
|
|
# https://leimao.github.io/blog/ONNX-IO-Stream/
|
|
|
|
|
print("loading model...")
|
|
|
|
|
stream.seek(0)
|
|
|
|
|
model_proto = onnx.load_model(stream, onnx.ModelProto, load_external_data=False)
|
|
|
|
|
print("load done!")
|
|
|
|
|
return cls(name, model_proto)
|
|
|
|
|
|
|
|
|
|
def reload(self):
|
|
|
|
|
@ -61,6 +63,7 @@ class onnxModifier:
|
|
|
|
|
self.initializer_name2module[initializer.name] = initializer
|
|
|
|
|
|
|
|
|
|
def change_batch_size(self, rebatch_info):
|
|
|
|
|
if not (rebatch_info): return
|
|
|
|
|
# https://github.com/onnx/onnx/issues/2182#issuecomment-881752539
|
|
|
|
|
rebatch_type = rebatch_info['type']
|
|
|
|
|
rebatch_value = rebatch_info['value']
|
|
|
|
|
@ -183,9 +186,10 @@ class onnxModifier:
|
|
|
|
|
|
|
|
|
|
self.graph.node.append(node)
|
|
|
|
|
|
|
|
|
|
def add_outputs(self, added_outputs, node_states):
|
|
|
|
|
def add_outputs(self, added_outputs):
|
|
|
|
|
# https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445
|
|
|
|
|
added_output_names = added_outputs.values()
|
|
|
|
|
if len(added_output_names) == 0: return
|
|
|
|
|
# filter out the deleted custom-added outputs
|
|
|
|
|
value_info_protos = []
|
|
|
|
|
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
|
|
|
|
|
@ -205,10 +209,10 @@ class onnxModifier:
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
self.add_outputs(modify_info['added_outputs'], modify_info['node_states'])
|
|
|
|
|
|
|
|
|
|
self.add_outputs(modify_info['added_outputs'])
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
|
print("saving model...")
|
|
|
|
|
if not os.path.exists(save_dir):
|
|
|
|
|
os.mkdir(save_dir)
|
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
|
@ -218,6 +222,7 @@ class onnxModifier:
|
|
|
|
|
# I turn off the onnx checker as a workaround.
|
|
|
|
|
# onnx.checker.check_model(self.model_proto)
|
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
|
print("model saved in {} !".format(save_dir))
|
|
|
|
|
|
|
|
|
|
def inference(self, input_shape=[1, 3, 224, 224], x=None, output_names=None):
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
|