`add node` feature works for some simple node like Abs

1123
ZhangGe6 4 years ago
parent 0face73b86
commit 975682eafe

@ -22,9 +22,9 @@ def modify_and_download_model():
modify_info = request.get_json() modify_info = request.get_json()
# print(modify_info) # print(modify_info)
onnx_modifier.reload() # allow for downloading for multiple times onnx_modifier.reload() # allow downloading for multiple times
onnx_modifier.remove_node_by_node_states(modify_info['node_states'])
onnx_modifier.modify_node_io_name(modify_info['node_renamed_io']) onnx_modifier.modify(modify_info)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
return 'OK', 200 return 'OK', 200

@ -94,7 +94,37 @@ class onnxModifier:
for i in range(len(node.output)): for i in range(len(node.output)):
if node.output[i] == src_name: if node.output[i] == src_name:
node.output[i] = dst_name node.output[i] = dst_name
# print(node.input, node.output)
def add_node(self, nodes_info):
for node_info in nodes_info.values():
name = node_info['properties']['name']
op_type = node_info['properties']['op_type']
attributes = node_info['attributes']
inputs = []
for key in node_info['inputs'].keys():
inputs += node_info['inputs'][key]
outputs = []
for key in node_info['outputs'].keys():
outputs += node_info['outputs'][key]
node = onnx.helper.make_node(
op_type=op_type,
inputs=inputs,
outputs=outputs,
name=name,
**attributes
)
# print(node)
self.graph.node.append(node)
def modify(self, modify_info):
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.add_node(modify_info['added_node_info'])
def check_and_save_model(self, save_dir='./modified_onnx'): def check_and_save_model(self, save_dir='./modified_onnx'):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
@ -111,9 +141,10 @@ class onnxModifier:
if __name__ == "__main__": if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def remove_node_by_node_states(): def remove_node_by_node_states():
@ -167,11 +198,21 @@ if __name__ == "__main__":
# for initializer in onnx_modifier.model_proto.graph.initializer: # for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name) # print(initializer.name)
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
pass
# explore_basic() # explore_basic()
def test_modify_node_io_name(): def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
onnx_modifier.modify_node_io_name(node_rename_io) onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
test_modify_node_io_name() # test_modify_node_io_name()
def test_add_node():
node_info = {'properties': {'domain': 'ai.onnx', 'op_type': 'Abs', 'name': 'custom_added_Abs0'}, 'attributes': {}, 'inputs': {'X': ['custom_input_0']}, 'outputs': {'Y': ['custom_output_1']}}
onnx_modifier.add_node(node_info)
onnx_modifier.check_and_save_model()
test_add_node()

@ -214,7 +214,9 @@ host.BrowserHost = class {
const downloadButton = this.document.getElementById('download-graph'); const downloadButton = this.document.getElementById('download-graph');
downloadButton.addEventListener('click', () => { downloadButton.addEventListener('click', () => {
// console.log(this)
console.log(this._view._graph._addedNode)
console.log(this._view._graph._renameMap)
// https://healeycodes.com/talking-between-languages // https://healeycodes.com/talking-between-languages
fetch('/download', { fetch('/download', {
// Declare what type of data we're sending // Declare what type of data we're sending
@ -226,6 +228,7 @@ host.BrowserHost = class {
body: JSON.stringify({ body: JSON.stringify({
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode))
} }
) )
@ -640,6 +643,22 @@ host.BrowserHost = class {
} }
return lo return lo
} }
// convert view.LightNodeInfo to Map object for easier transmission to Python backend
parseLightNodeInfo2Map(nodes_info) {
var res_map = new Map()
for (const [modelNodeName, node_info] of nodes_info) {
var node_info_map = new Map()
node_info_map.set('properties', node_info.properties)
node_info_map.set('attributes', node_info.attributes)
node_info_map.set('inputs', node_info.inputs)
node_info_map.set('outputs', node_info.outputs)
res_map.set(modelNodeName, node_info_map)
}
return res_map
}
}; };
host.Dropdown = class { host.Dropdown = class {

@ -540,7 +540,7 @@ onnx.Graph = class {
const input = schema.inputs[i] const input = schema.inputs[i]
var node_info_input = node_info.inputs.get(input.name) var node_info_input = node_info.inputs.get(input.name)
console.log(node_info_input) // console.log(node_info_input)
var arg_list = [] var arg_list = []
if (input.list) { if (input.list) {

@ -682,6 +682,7 @@ class NodeAttributeView {
var attr_input = document.createElement("INPUT"); var attr_input = document.createElement("INPUT");
attr_input.setAttribute("type", "text"); attr_input.setAttribute("type", "text");
attr_input.setAttribute("size", "42");
attr_input.setAttribute("value", content ? content : 'undefined'); attr_input.setAttribute("value", content ? content : 'undefined');
attr_input.addEventListener('input', (e) => { attr_input.addEventListener('input', (e) => {
// console.log(e.target.value); // console.log(e.target.value);
@ -855,6 +856,7 @@ sidebar.ArgumentView = class {
var arg_input = document.createElement("INPUT"); var arg_input = document.createElement("INPUT");
arg_input.setAttribute("type", "text"); arg_input.setAttribute("type", "text");
arg_input.setAttribute("size", "42");
arg_input.setAttribute("value", name); arg_input.setAttribute("value", name);
arg_input.addEventListener('input', (e) => { arg_input.addEventListener('input', (e) => {
// console.log(this._argument) // console.log(this._argument)

@ -875,21 +875,21 @@ view.View = class {
var node = this._graphs[0].make_custom_add_node(node_info) var node = this._graphs[0].make_custom_add_node(node_info)
// console.log(node) // console.log(node)
// padding empty array for LightNodeInfo.inputs/outputs (only when initializing) for (const input of node.inputs) {
if (this.lastViewGraph._addedNode.get(modelNodeName).inputs.size == 0) { var input_list_names = []
for (const arg of input._arguments) {
for (var input of node.inputs) { input_list_names.push(arg.name)
var arg_len = input._arguments.length
this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len))
} }
} this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names)
if (this.lastViewGraph._addedNode.get(modelNodeName).outputs.size == 0) { }
for (var output of node.outputs) { for (const output of node.outputs) {
var arg_len = output._arguments.length var output_list_names = []
this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, new Array(arg_len)) for (const arg of output._arguments) {
output_list_names.push(arg.name)
} }
this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, output_list_names)
} }
} }
@ -1152,30 +1152,8 @@ view.Graph = class extends grapher.Graph {
this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties)) this._addedNode.set(modelNodeName, new view.LightNodeInfo(properties))
// console.log(this._addedNode) // console.log(this._addedNode)
// refresh
// this.refresh_added_node()
} }
// refresh_added_node() {
// this.view._graphs[0].reset_custom_added_node()
// // for (const node_info of this._addedNode.values()) {
// for (const [modelNodeName, node_info] of this._addedNode) {
// // console.log(node)
// var node = this.view._graphs[0].make_custom_add_node(node_info)
// // padding empty array for LightNodeInfo.inputs/outputs
// for (var input of node.inputs) {
// var arg_len = input._arguments.length
// this._addedNode.get(modelNodeName).inputs.set(input.name, new Array(arg_len))
// }
// }
// // console.log(this.view._graphs[0].nodes)
// console.log(this._addedNode)
// }
changeNodeAttribute(modelNodeName, attributeName, targetValue) { changeNodeAttribute(modelNodeName, attributeName, targetValue) {
if (this._addedNode.has(modelNodeName)) { if (this._addedNode.has(modelNodeName)) {
this._addedNode.get(modelNodeName).attributes.set(attributeName, targetValue) this._addedNode.get(modelNodeName).attributes.set(attributeName, targetValue)

Loading…
Cancel
Save