|
|
|
|
@ -8,7 +8,7 @@ import copy
|
|
|
|
|
import numpy as np
|
|
|
|
|
import onnx
|
|
|
|
|
import onnxruntime as rt
|
|
|
|
|
from utils import make_node
|
|
|
|
|
from utils import make_new_node, make_attr_changed_node
|
|
|
|
|
|
|
|
|
|
class onnxModifier:
|
|
|
|
|
def __init__(self, model_name, model_proto):
|
|
|
|
|
@ -117,12 +117,23 @@ class onnxModifier:
|
|
|
|
|
if node.output[i] == src_name:
|
|
|
|
|
node.output[i] = dst_name
|
|
|
|
|
|
|
|
|
|
def modify_node_attr(self, node_changed_attr):
|
|
|
|
|
# we achieve it by deleting the original node and make a (copied) new node
|
|
|
|
|
# print(node_changed_attr)
|
|
|
|
|
for node_name in node_changed_attr.keys():
|
|
|
|
|
orig_node = self.node_name2module[node_name]
|
|
|
|
|
attr_changed_node = make_attr_changed_node(orig_node, node_changed_attr[node_name])
|
|
|
|
|
self.graph.node.remove(self.node_name2module[node_name])
|
|
|
|
|
self.graph.node.append(attr_changed_node)
|
|
|
|
|
|
|
|
|
|
# update the node_name2module and initializer_name2module
|
|
|
|
|
self.gen_name2module_map()
|
|
|
|
|
|
|
|
|
|
def add_node(self, nodes_info, node_states):
|
|
|
|
|
for node_info in nodes_info.values():
|
|
|
|
|
if node_states[node_info['properties']['name']] == "Deleted":
|
|
|
|
|
continue
|
|
|
|
|
node = make_node(node_info)
|
|
|
|
|
node = make_new_node(node_info)
|
|
|
|
|
# print(node)
|
|
|
|
|
|
|
|
|
|
self.graph.node.append(node)
|
|
|
|
|
@ -131,17 +142,23 @@ class onnxModifier:
|
|
|
|
|
def modify(self, modify_info):
|
|
|
|
|
# print(modify_info['node_states'])
|
|
|
|
|
# print(modify_info['node_renamed_io'])
|
|
|
|
|
print(modify_info['node_changed_attr'])
|
|
|
|
|
# print(modify_info['added_node_info'])
|
|
|
|
|
self.remove_node_by_node_states(modify_info['node_states'])
|
|
|
|
|
self.modify_node_io_name(modify_info['node_renamed_io'])
|
|
|
|
|
self.modify_node_attr(modify_info['node_changed_attr'])
|
|
|
|
|
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_save_model(self, save_dir='./modified_onnx'):
|
|
|
|
|
if not os.path.exists(save_dir):
|
|
|
|
|
os.mkdir(save_dir)
|
|
|
|
|
|
|
|
|
|
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
|
|
|
|
|
onnx.checker.check_model(self.model_proto)
|
|
|
|
|
|
|
|
|
|
# adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
|
|
|
|
|
# so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked
|
|
|
|
|
# I turn off the onnx checker as a workaround.
|
|
|
|
|
# onnx.checker.check_model(self.model_proto)
|
|
|
|
|
onnx.save(self.model_proto, save_path)
|
|
|
|
|
|
|
|
|
|
def inference(self, x=None, output_names=None):
|
|
|
|
|
@ -241,7 +258,7 @@ if __name__ == "__main__":
|
|
|
|
|
# print(inp.name)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
remove_node_by_node_states()
|
|
|
|
|
# remove_node_by_node_states()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_modify_node_io_name():
|
|
|
|
|
@ -257,7 +274,17 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
onnx_modifier.inference()
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
|
|
|
|
|
# test_add_node()
|
|
|
|
|
|
|
|
|
|
def test_change_node_attr():
|
|
|
|
|
# changed_attr = {'Clip_3': {'max': 5}}
|
|
|
|
|
changed_attr = {'Conv_2': {'group': 64}}
|
|
|
|
|
|
|
|
|
|
onnx_modifier.modify_node_attr(changed_attr)
|
|
|
|
|
|
|
|
|
|
onnx_modifier.check_and_save_model()
|
|
|
|
|
|
|
|
|
|
test_change_node_attr()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|