support removing legacy isolated nodes (like Constant) automatically (https://github.com/ZhangGe6/onnx-modifier/issues/21)

1123
ZhangGe6 3 years ago
parent d4a29ef241
commit e6ec898f41

1
.gitignore vendored

@ -11,5 +11,6 @@ dist/
# self maintained files
gym/
local_tmp/
*ppt
*pptx

@ -1,5 +1,6 @@
# TODO
- [ ] add `shape inference` feature (mentioned in [this issue](https://github.com/ZhangGe6/onnx-modifier/issues/22))
- [ ] **ensure the model is fully loaded before modify() is called.**
- otherwise `NameError: name 'onnx_modifier' is not defined` error will be invoked.
- [ ] support desktop application.
@ -7,13 +8,13 @@
- [ ] Linux
- [ ] support more flexible downloading schema
- [ ] As this [request](https://github.com/ZhangGe6/onnx-modifier/pull/5) notes, the current downloading schema prevents `onnx-modifier ` from being deployed remotely as a service.
- [ ] support adding more complicated nodes (which has some simple parameters like `reshape`).
- [ ] support combine models.
- [ ] support user-defined input/output number when the type of node's input/output is list.
- [ ] slim the code.
- [ ] because some `.js` files (like electron.js and even python.js) in the `static` folder and `electron.html` in `templates` folder are legacy of Netron and can be further slimmed.
- [x] support adding model input/output node.
- [x] fix issue that "extra model inputs" emerges after deleting nodes. [issue#12](https://github.com/ZhangGe6/onnx-modifier/issues/12)
- [x] support adding more complicated nodes (which has some simple parameters like `reshape`).
# Some known reference issues/feature requests

@ -119,21 +119,21 @@ class onnxModifier:
# print('removing node {} ...'.format(node_name))
self.remove_node_by_name(node_name)
# remove node initializers (parameters), aka, keep and only keep the initializers of left nodes
left_node_inputs = []
for left_node in self.graph.node:
left_node_inputs += left_node.input
remained_node_inputs = []
for remained_node in self.graph.node:
remained_node_inputs += remained_node.input
# remove node initializers (parameters), aka, keep and only keep the initializers of remained nodes
for init_name in self.initializer_name2module.keys():
if not init_name in left_node_inputs:
if not init_name in remained_node_inputs:
self.initializer.remove(self.initializer_name2module[init_name])
# remove the (model) inputs related to deleted nodes
# https://github.com/ZhangGe6/onnx-modifier/issues/12
for input_name in self.graph_input_names:
if not input_name in left_node_inputs:
if not input_name in remained_node_inputs:
self.graph.input.remove(self.node_name2module[input_name])
def modify_node_io_name(self, node_renamed_io):
for node_name in node_renamed_io.keys():
if node_name not in self.node_name2module.keys():
@ -192,7 +192,9 @@ class onnxModifier:
# filter out the deleted custom-added outputs
value_info_protos = []
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
print(added_output_names)
for value_info in shape_info.graph.value_info:
print(value_info.name)
if value_info.name in added_output_names:
value_info_protos.append(value_info)
self.graph.output.extend(value_info_protos)
@ -221,6 +223,33 @@ class onnxModifier:
self.initializer.append(initializer_tensor)
self.initializer_name2module[init_name] = initializer_tensor
def post_process(self):
def remove_isolated_nodes():
# remove the remained corresponding isolated nodes, like Constant
remained_node_inputs, remained_node_outputs = [], []
for remained_node in self.graph.node:
remained_node_inputs += remained_node.input
remained_node_outputs += remained_node.output
for remained_node in self.graph.node:
# delete the node if it does not serve as the input or output of any other nodes
unused = True
for output in remained_node.output:
if output in remained_node_inputs:
unused = False
break
for input in remained_node.input:
if input in remained_node_outputs:
unused = False
break
if unused:
self.graph.node.remove(self.node_name2module[remained_node.name])
for inp in remained_node.input:
if inp in self.initializer_name2module.keys():
self.initializer.remove(self.initializer_name2module[inp])
remove_isolated_nodes()
def modify(self, modify_info):
'''
1. Some functions, such as modify_initializer(), should be placed
@ -242,6 +271,8 @@ class onnxModifier:
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
self.add_outputs(modify_info['added_outputs'])
self.post_process()
def check_and_save_model(self, save_dir='./modified_onnx'):
print("saving model...")
@ -280,8 +311,9 @@ if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_test_edit_init.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_test_edit_init.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\tiny_squeezenet1.0-3.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():
@ -366,12 +398,12 @@ if __name__ == "__main__":
def test_inference():
onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=["onnx::Transpose_368"])
test_inference()
# test_inference()
def test_add_output():
# print(onnx_modifier.graph.output)
onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
# print(onnx_modifier.graph.output)
onnx_modifier.add_outputs({'0': 'out'})
print(onnx_modifier.graph.output)
onnx_modifier.check_and_save_model()
# test_add_output()
@ -406,4 +438,10 @@ if __name__ == "__main__":
# print(onnx_modifier.initializer_name2module.keys())
# for initializer in onnx_modifier.initializer:
# print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}")
# test_modify_new_initializer()
# test_modify_new_initializer()
def test_remove_isolated_nodes():
modify_info = {'node_states': {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist'}, 'node_renamed_io': {'Conv3': {'pool1_1': 'conv1_2'}, 'MaxPool2': {'conv1_2': 'conv1'}}, 'node_changed_attr': {}, 'added_node_info': {}, 'added_outputs': {}, 'rebatch_info': {}, 'changed_initializer': {}}
onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model()
test_remove_isolated_nodes()

@ -722,8 +722,8 @@ host.BrowserHost = class {
// convert view.LightNodeInfo to Map object for easier transmission to Python backend
parseAddedLightNodeInfo2Map(nodes_info, initializer_info) {
console.log(nodes_info)
console.log(initializer_info)
// console.log(nodes_info)
// console.log(initializer_info)
var res_map = new Map()
for (const [modelNodeName, node_info] of nodes_info) {
var node_info_map = new Map()

Loading…
Cancel
Save