Adding node feature is tested for some ops and works as expected. Currently, adding node with initializers is not supported

1123
ZhangGe6 4 years ago
parent 7590775f74
commit bb58bad176

@ -8,6 +8,10 @@ add node NMS: https://github.com/onnx/onnx/issues/2216
add node (add preprocess nodes): https://zhuanlan.zhihu.com/p/394395167
combine models:
- https://stackoverflow.com/questions/66178085/can-i-combine-two-onnx-graphs-together-passing-the-output-from-one-as-input-to
- https://www.zhihu.com/people/kai-xin-zui-zhong-yao-76/posts
# modify attribute of nodes
topk: https://github.com/onnx/onnx/issues/2921
@ -26,11 +30,14 @@ http://yyixx.com/docs/algo/onnx/
# 待做的
bug(fixed): 不可连续添加某一种类型的节点(无反应)
boost: 支持添加更复杂的节点
boost: 直接使用侧边栏inputs/outputs属性框完成重命名并提供reset功能
boost: 支持处理属性的修改
boost: 支持添加更复杂的节点
question: 在add()函数里为什么对conv的inputs进行遍历只能得到X而得不到W和B
- 暂时只支持新增节点,而不支持已有节点开放(防止出现不必要的错误)
(fixed)bug: 不可连续添加某一种类型的节点(无反应)
(solved)question: 在add()函数里为什么对conv的inputs进行遍历只能得到X而得不到W和B
- 因为遍历的条件里有判断是否有initializer的条件
# 其他

@ -5,7 +5,10 @@
import os
import copy
import numpy as np
import onnx
import onnxruntime as rt
from utils import make_node
class onnxModifier:
def __init__(self, model_name, model_proto):
@ -63,6 +66,9 @@ class onnxModifier:
def remove_node_by_node_states(self, node_states):
# remove node from graph
for node_name, node_state in node_states.items():
if not (node_name in self.node_name2module):
# for custom added node here
continue
if node_state == 'Deleted':
if node_name in self.graph_output_names:
# print('removing output {} ...'.format(node_name))
@ -98,31 +104,15 @@ class onnxModifier:
def add_node(self, nodes_info):
for node_info in nodes_info.values():
name = node_info['properties']['name']
op_type = node_info['properties']['op_type']
attributes = node_info['attributes']
inputs = []
for key in node_info['inputs'].keys():
inputs += node_info['inputs'][key]
outputs = []
for key in node_info['outputs'].keys():
outputs += node_info['outputs'][key]
node = onnx.helper.make_node(
op_type=op_type,
inputs=inputs,
outputs=outputs,
name=name,
**attributes
)
node = make_node(node_info)
# print(node)
self.graph.node.append(node)
def modify(self, modify_info):
print(modify_info['node_renamed_io'])
# print(modify_info['node_renamed_io'])
print(modify_info['added_node_info'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.add_node(modify_info['added_node_info'])
@ -135,11 +125,29 @@ class onnxModifier:
onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
def inference(self):
# model_proto_bytes = onnx._serialize(model_proto_from_stream)
# inference_session = rt.InferenceSession(model_proto_bytes)
pass
def inference(self, x=None, output_names=None):
if not x:
input_shape = [1, 3, 224, 224]
x = np.random.randn(*input_shape).astype(np.float32)
if not output_names:
output_name = self.graph.node[-1].output[0]
# output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.INT64, shape=[])
output_value_info = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape=[])
self.graph.output.append(output_value_info)
model_proto_bytes = onnx._serialize(self.model_proto)
inference_session = rt.InferenceSession(model_proto_bytes)
input_name = inference_session.get_inputs()[0].name
output_name = inference_session.get_outputs()[0].name
# print(input_name)
# print(output_name)
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
out = inference_session.run(None, {input_name: x})[0]
# print(out)
if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
@ -147,7 +155,25 @@ if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():
print(type(onnx_modifier.model_proto.graph.initializer))
print(dir(onnx_modifier.model_proto.graph.initializer))
print(len(onnx_modifier.model_proto.graph.node))
print(len(onnx_modifier.model_proto.graph.initializer))
for node in onnx_modifier.model_proto.graph.node:
print(node.name)
print(node.input)
print()
# for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name)
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
pass
# explore_basic()
def remove_node_by_node_states():
print(len(onnx_modifier.graph.node))
print(len(onnx_modifier.graph.initializer))
@ -184,24 +210,7 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model()
# remove_node_by_node_states()
def explore_basic():
print(type(onnx_modifier.model_proto.graph.initializer))
print(dir(onnx_modifier.model_proto.graph.initializer))
print(len(onnx_modifier.model_proto.graph.node))
print(len(onnx_modifier.model_proto.graph.initializer))
for node in onnx_modifier.model_proto.graph.node:
print(node.name)
print(node.input)
print()
# for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name)
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
pass
# explore_basic()
def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
onnx_modifier.modify_node_io_name(node_rename_io)
@ -209,9 +218,11 @@ if __name__ == "__main__":
# test_modify_node_io_name()
def test_add_node():
node_info = {'properties': {'domain': 'ai.onnx', 'op_type': 'Abs', 'name': 'custom_added_Abs0'}, 'attributes': {}, 'inputs': {'X': ['custom_input_0']}, 'outputs': {'Y': ['custom_output_1']}}
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
onnx_modifier.add_node(node_info)
onnx_modifier.inference()
onnx_modifier.check_and_save_model()
test_add_node()

@ -244,7 +244,7 @@ host.BrowserHost = class {
}
else {
// swal("Error happens!", "You are kindly to create an issue on https://github.com/ZhangGe6/onnx-modifier", "error");
swal("Error happens!", "You can check the log and kindly create an issue on https://github.com/ZhangGe6/onnx-modifier", "error");
swal("Error happens!", "You are kindly to check the log and create an issue on https://github.com/ZhangGe6/onnx-modifier", "error");
// alert('Error happens, you can find it out or create an issue on https://github.com/ZhangGe6/onnx-modifier')
}
});

@ -514,7 +514,7 @@ onnx.Graph = class {
reset_custom_added_node() {
this._custom_added_node = []
this._custom_add_node_io_idx = 0
// this._custom_add_node_io_idx = 0
}
toString() {
@ -645,7 +645,7 @@ onnx.Graph = class {
node_info.properties.get('op_type'),
node_info.properties.get('domain'),
node_info.properties.get('name'),
schema.description,
null, // schema.description, // omit it to save sidebar space. The node description can also be seen in the node `type` expander
attributes,
inputs,
outputs
@ -693,9 +693,9 @@ onnx.Argument = class {
// this._renamed = false;
// this._new_name = null;
console.log(original_name)
// console.log(original_name)
this.original_name = original_name || name
console.log(this.original_name)
// console.log(this.original_name)
}

@ -201,24 +201,26 @@ sidebar.NodeSidebar = class {
this._elements.push(this._host.document.createElement('hr'));
this.add_separator(this._elements, 'sidebar-view-separator')
this._addHeader('Node deleting helper');
this._addButton('Delete With Children');
this.add_span()
this._addButton('Delete Single Node');
this.add_span()
this._addButton('Recover Node');
this.add_separator(this._elements, 'sidebar-view-separator');
this._addHeader('Rename helper');
if (inputs && inputs.length > 0) {
for (const input of inputs) {
this.add_rename_aux_element(input.arguments);
}
}
if (outputs && outputs.length > 0) {
for (const output of outputs) {
this.add_rename_aux_element(output.arguments);
}
}
// deprecated
// this.add_separator(this._elements, 'sidebar-view-separator');
// this._addHeader('Rename helper');
// if (inputs && inputs.length > 0) {
// for (const input of inputs) {
// this.add_rename_aux_element(input.arguments);
// }
// }
// if (outputs && outputs.length > 0) {
// for (const output of outputs) {
// this.add_rename_aux_element(output.arguments);
// }
// }
// this.add_separator(this._elements, 'sidebar-view-separator');
// this._addHeader('Add children node');
@ -648,6 +650,7 @@ class NodeAttributeView {
this._element.appendChild(this._expander);
}
const value = this._attribute.value;
console.log(this._attribute.name, value, type)
switch (type) {
case 'graph': {
const line = this._host.document.createElement('div');
@ -687,8 +690,8 @@ class NodeAttributeView {
attr_input.setAttribute("size", "42");
attr_input.setAttribute("value", content ? content : 'undefined');
attr_input.addEventListener('input', (e) => {
// console.log(e.target.value);
this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, e.target.value);
console.log(e.target.value);
this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, this.parse_value(e.target.value, type));
// console.log(this._host._view._graph._renameMap);
});
@ -758,6 +761,34 @@ class NodeAttributeView {
}
}
}
parse_value(value, type) {
if (value == 'undefined') {
// alert("");
return value
}
switch (type) {
case "int64":
return parseInt(value)
// case ""
case "int64[]":
var val = []
for (var v of value.split(",")) {
val.push(parseInt(v))
}
return val
case "float32":
return parseFloat(value)
case "float32[]":
var val = []
for (var v of value.split(",")) {
val.push(parseFloat(v))
}
return val
}
}
}
sidebar.ParameterView = class {

@ -0,0 +1 @@
from .make_nodes import *

@ -0,0 +1,28 @@
import onnx
def make_node(node_info):
name = node_info['properties']['name']
op_type = node_info['properties']['op_type']
attributes = node_info['attributes']
# attributes = {k: v for k, v in node_info['attributes'].items() if not v == 'undefined'}
# print(attributes)
inputs = []
for key in node_info['inputs'].keys():
inputs += node_info['inputs'][key]
outputs = []
for key in node_info['outputs'].keys():
outputs += node_info['outputs'][key]
# https://github.com/onnx/onnx/blob/main/onnx/helper.py#L82
node = onnx.helper.make_node(
op_type=op_type,
inputs=inputs,
outputs=outputs,
name=name,
**attributes
)
# print(node)
return node
Loading…
Cancel
Save