|
|
|
|
@ -2,14 +2,30 @@
|
|
|
|
|
# https://github.com/saurabh-shandilya/onnx-utils
|
|
|
|
|
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
|
|
|
|
|
|
|
|
|
|
import io
|
|
|
|
|
import os
|
|
|
|
|
import copy
|
|
|
|
|
import onnx
|
|
|
|
|
|
|
|
|
|
class onnxModifier:
|
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
|
self.model_name = model_name
|
|
|
|
|
self.model_proto = model_proto
|
|
|
|
|
self.model_proto_backup = model_proto
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_path(cls, model_path):
|
|
|
|
|
model_name = os.path.basename(model_path)
|
|
|
|
|
model_proto = onnx.load(model_path)
|
|
|
|
|
return cls(model_name, model_proto)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_name_stream(cls, name, stream):
|
|
|
|
|
# https://leimao.github.io/blog/ONNX-IO-Stream/
|
|
|
|
|
stream.seek(0)
|
|
|
|
|
model_proto = onnx.load_model(stream, onnx.ModelProto)
|
|
|
|
|
return cls(name, model_proto)
|
|
|
|
|
|
|
|
|
|
def reload(self):
|
|
|
|
|
self.model_proto = copy.deepcopy(self.model_proto_backup)
|
|
|
|
|
self.graph = self.model_proto.graph
|
|
|
|
|
self.initializer = self.model_proto.graph.initializer
|
|
|
|
|
|
|
|
|
|
@ -34,20 +50,7 @@ class onnxModifier:
|
|
|
|
|
self.initilizer_name2module = dict()
|
|
|
|
|
for initializer in self.initializer:
|
|
|
|
|
self.initilizer_name2module[initializer.name] = initializer
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_model_path(cls, model_path):
|
|
|
|
|
model_name = os.path.basename(model_path)
|
|
|
|
|
model_proto = onnx.load(model_path)
|
|
|
|
|
return cls(model_name, model_proto)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_name_stream(cls, name, stream):
|
|
|
|
|
# https://leimao.github.io/blog/ONNX-IO-Stream/
|
|
|
|
|
stream.seek(0)
|
|
|
|
|
model_proto = onnx.load_model(stream, onnx.ModelProto)
|
|
|
|
|
return cls(name, model_proto)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_node_by_name(self, node_name):
|
|
|
|
|
# remove node in graph
|
|
|
|
|
self.graph.node.remove(self.node_name2module[node_name])
|
|
|
|
|
|