ZhangGe6 4 years ago
parent f45da8d4fb
commit 38649877eb

@ -109,11 +109,12 @@ class onnxModifier:
# self.initializer.remove(self.initializer_name2module[init_name]) # self.initializer.remove(self.initializer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io): def modify_node_io_name(self, node_renamed_io):
# print(node_renamed_io)
for node_name in node_renamed_io.keys(): for node_name in node_renamed_io.keys():
if node_name not in self.node_name2module.keys():
# custom added nodes or custom added model outputs
continue
renamed_ios = node_renamed_io[node_name] renamed_ios = node_renamed_io[node_name]
for src_name, dst_name in renamed_ios.items(): for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name)
node = self.node_name2module[node_name] node = self.node_name2module[node_name]
if node_name in self.graph_input_names: if node_name in self.graph_input_names:
node.name = dst_name node.name = dst_name
@ -149,16 +150,28 @@ class onnxModifier:
self.graph.node.append(node) self.graph.node.append(node)
def add_outputs(self, added_outputs, node_states):
# https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445
added_output_names = added_outputs.values()
# filter out the deleted custom-added outputs
value_info_protos = []
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
for value_info in shape_info.graph.value_info:
if value_info.name in added_output_names:
value_info_protos.append(value_info)
self.graph.output.extend(value_info_protos)
def modify(self, modify_info): def modify(self, modify_info):
# print(modify_info['node_states']) # print(modify_info['node_states'])
# print(modify_info['node_renamed_io']) # print(modify_info['node_renamed_io'])
# print(modify_info['node_changed_attr']) # print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info']) # print(modify_info['added_node_info'])
# print(modify_info['added_outputs'])
self.remove_node_by_node_states(modify_info['node_states']) self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr']) self.modify_node_attr(modify_info['node_changed_attr'])
self.add_node(modify_info['added_node_info'], modify_info['node_states']) self.add_node(modify_info['added_node_info'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs'], modify_info['node_states'])
def check_and_save_model(self, save_dir='./modified_onnx'): def check_and_save_model(self, save_dir='./modified_onnx'):
@ -191,15 +204,15 @@ class onnxModifier:
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506 # This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
out = inference_session.run(None, {input_name: x})[0] out = inference_session.run(None, {input_name: x})[0]
# print(out) print(out.shape)
if __name__ == "__main__": if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
@ -260,7 +273,6 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
# remove_node_by_node_states() # remove_node_by_node_states()
def test_modify_node_io_name(): def test_modify_node_io_name():
node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}} node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
onnx_modifier.modify_node_io_name(node_rename_io) onnx_modifier.modify_node_io_name(node_rename_io)
@ -285,34 +297,14 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
# test_change_node_attr() # test_change_node_attr()
def test_inference():
onnx_modifier.inference()
test_inference()
def debug_remove_node_by_node_states(): def test_add_output():
# print(len(onnx_modifier.graph.node)) # print(onnx_modifier.graph.output)
# print(len(onnx_modifier.graph.initializer)) onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
# print(onnx_modifier.graph.output)
# print(onnx_modifier.node_name2module.keys())
# print(onnx_modifier.graph.node)
# for node in onnx_modifier.graph.node:
# print(node.name)
# print(node.input)
# print(node.output)
# print('\noriginal input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
node_states = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist', 'Relu11': 'Exist', 'Conv12': 'Exist', 'Relu13': 'Exist', 'Conv14': 'Exist', 'Relu15': 'Exist', 'Concat16': 'Exist', 'MaxPool17': 'Exist', 'Conv18': 'Exist', 'Relu19': 'Exist', 'Conv20': 'Exist', 'Relu21': 'Exist', 'Conv22': 'Exist', 'Relu23': 'Exist', 'Concat24': 'Exist', 'Conv25': 'Exist', 'Relu26': 'Exist', 'Conv27': 'Exist', 'Relu28': 'Exist', 'Conv29': 'Exist', 'Relu30': 'Exist', 'Concat31': 'Exist', 'MaxPool32': 'Exist', 'Conv33': 'Exist', 'Relu34': 'Exist', 'Conv35': 'Exist', 'Relu36': 'Exist', 'Conv37': 'Exist', 'Relu38': 'Exist', 'Concat39': 'Exist', 'Conv40': 'Exist', 'Relu41': 'Exist', 'Conv42': 'Exist', 'Relu43': 'Exist', 'Conv44': 'Exist', 'Relu45': 'Exist', 'Concat46': 'Exist', 'Conv47': 'Exist', 'Relu48': 'Exist', 'Conv49': 'Exist', 'Relu50': 'Exist', 'Conv51': 'Exist', 'Relu52': 'Exist', 'Concat53': 'Exist', 'Conv54': 'Exist', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'out_softmaxout_1': 'Deleted'}
# print('\graph input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
onnx_modifier.remove_node_by_node_states(node_states)
print('\nleft input')
for inp in onnx_modifier.graph.input:
print(inp.name)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
debug_remove_node_by_node_states() # test_add_output()

@ -229,10 +229,10 @@ host.BrowserHost = class {
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)) 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
} 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
this._view._graph._renameMap, this._view._graph._modelNodeName2State))
) })
}).then(function (response) { }).then(function (response) {
return response.text(); return response.text();
}).then(function (text) { }).then(function (text) {
@ -264,9 +264,6 @@ host.BrowserHost = class {
this._view._updateGraph(); this._view._updateGraph();
}) })
this.document.getElementById('version').innerText = this.version; this.document.getElementById('version').innerText = this.version;
if (this._meta.file) { if (this._meta.file) {
@ -678,6 +675,33 @@ host.BrowserHost = class {
return lo return lo
} }
// this function does 2 things:
// 1. rename the addedOutputs with their new names using renameMap. Because addedOutputs are stored in lists,
// it may be not easy to rename them while editing. (Of course there may be a better way to do this)
// 2. filter out the custom output which is added, but deleted later
process_added_outputs(addedOutputs, renameMap, modelNodeName2State) {
var processed = []
for (let i = 0; i < addedOutputs.length; ++i) {
if (modelNodeName2State.get("out_" + addedOutputs[i]) == "Exist") {
processed.push(addedOutputs[i]);
}
}
for (let i = 0; i < processed.length; ++i) {
if (renameMap.get("out_" + processed[i])) {
processed[i] = renameMap.get("out_" + processed[i]).get(processed[i]);
}
}
return processed;
}
// https://stackoverflow.com/a/4215753/10096987
arrayToObject(arr) {
var rv = {};
for (var i = 0; i < arr.length; ++i)
if (arr[i] !== undefined) rv[i] = arr[i];
return rv;
}
// convert view.LightNodeInfo to Map object for easier transmission to Python backend // convert view.LightNodeInfo to Map object for easier transmission to Python backend
parseLightNodeInfo2Map(nodes_info) { parseLightNodeInfo2Map(nodes_info) {
var res_map = new Map() var res_map = new Map()

@ -438,6 +438,7 @@ onnx.Graph = class {
this._custom_add_node_io_idx = 0 this._custom_add_node_io_idx = 0
this._custom_added_node = [] this._custom_added_node = []
this._custom_added_outputs = []
// model parameter assignment here! // model parameter assignment here!
// console.log(graph) // console.log(graph)
@ -504,7 +505,8 @@ onnx.Graph = class {
} }
get outputs() { get outputs() {
return this._outputs; // return this._outputs;
return this._outputs.concat(this._custom_added_outputs);
} }
get nodes() { get nodes() {
@ -632,6 +634,15 @@ onnx.Graph = class {
return custom_add_node; return custom_add_node;
} }
reset_custom_added_outputs() {
this._custom_added_outputs = [];
}
add_output(name) {
const argument = this._context.argument(name);
this._custom_added_outputs.push(new onnx.Parameter(name, [ argument ]));
}
}; };
onnx.Parameter = class { onnx.Parameter = class {

@ -24,6 +24,7 @@ grapher.Graph = class {
this._changedAttributes = new Map(); this._changedAttributes = new Map();
this._addedNode = new Map(); this._addedNode = new Map();
this._addedOutputs = [];
} }
get options() { get options() {

@ -210,6 +210,9 @@ sidebar.NodeSidebar = class {
this.add_separator(this._elements, 'sidebar-view-separator') this.add_separator(this._elements, 'sidebar-view-separator')
this._addButton('Enter'); this._addButton('Enter');
this._addHeader('Output adding helper');
this._addButton('Add Output');
// deprecated // deprecated
// this.add_separator(this._elements, 'sidebar-view-separator'); // this.add_separator(this._elements, 'sidebar-view-separator');
// this._addHeader('Rename helper'); // this._addHeader('Rename helper');
@ -272,8 +275,6 @@ sidebar.NodeSidebar = class {
} }
} }
} }
} }
render() { render() {
@ -356,7 +357,11 @@ sidebar.NodeSidebar = class {
this._host._view._updateGraph() this._host._view._updateGraph()
}); });
} }
if (title === 'Add Output') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.add_output(this._modelNodeName)
});
}
} }
// deprecated // deprecated

@ -464,7 +464,6 @@ view.View = class {
this.refreshModelInputOutput() this.refreshModelInputOutput()
this.refreshNodeArguments() this.refreshNodeArguments()
this.refreshNodeAttributes() this.refreshNodeAttributes()
} }
return active_graph return active_graph
@ -580,7 +579,8 @@ view.View = class {
viewGraph._renameMap = this.lastViewGraph._renameMap; viewGraph._renameMap = this.lastViewGraph._renameMap;
viewGraph._changedAttributes = this.lastViewGraph._changedAttributes; viewGraph._changedAttributes = this.lastViewGraph._changedAttributes;
viewGraph._addedNode = this.lastViewGraph._addedNode; viewGraph._addedNode = this.lastViewGraph._addedNode;
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey;
viewGraph._addedOutputs = this.lastViewGraph._addedOutputs;
// console.log(viewGraph._renameMap); // console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State) // console.log(viewGraph._modelNodeName2State)
} }
@ -866,7 +866,7 @@ view.View = class {
} }
// re-generate the added node according to _addedNode // re-generate the added node according to _addedNode according to the latest _addedNode
refreshAddedNode() { refreshAddedNode() {
this._graphs[0].reset_custom_added_node() this._graphs[0].reset_custom_added_node()
// for (const node_info of this._addedNode.values()) { // for (const node_info of this._addedNode.values()) {
@ -881,7 +881,6 @@ view.View = class {
input_list_names.push(arg.name) input_list_names.push(arg.name)
} }
this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names) this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names)
} }
for (const output of node.outputs) { for (const output of node.outputs) {
@ -979,7 +978,13 @@ view.View = class {
} }
} }
for (var output of this._graphs[0]._outputs) { // create and add new output to graph
this._graphs[0].reset_custom_added_outputs();
for (var output_name of this.lastViewGraph._addedOutputs) {
this._graphs[0].add_output(output_name);
}
// console.log(this._graphs[0].outputs)
for (var output of this._graphs[0].outputs) {
var output_orig_name = output.arguments[0].original_name var output_orig_name = output.arguments[0].original_name
if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) { if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name // for model input and output, node.modelNodeName == element.original_name
@ -1011,6 +1016,7 @@ view.View = class {
} }
} }
} }
// console.log(this.lastViewGraph._renameMap)
} }
} }
} }
@ -1190,7 +1196,6 @@ view.Graph = class extends grapher.Graph {
} }
} }
for (const output of graph.outputs) { for (const output of graph.outputs) {
const viewOutput = this.createOutput(output); const viewOutput = this.createOutput(output);
for (const argument of output.arguments) { for (const argument of output.arguments) {
@ -1238,6 +1243,17 @@ view.Graph = class extends grapher.Graph {
} }
} }
add_output(node_name) {
var model_node = this._modelNodeName2ModelNode.get(node_name);
for (var output of model_node.outputs) {
for (var argument of output.arguments) {
this._addedOutputs.push(argument.name);
}
}
// console.log(this._addedOutputs);
this.view._updateGraph();
}
resetGraph() { resetGraph() {
// reset node states // reset node states
for (const nodeId of this.nodes.keys()) { for (const nodeId of this.nodes.keys()) {
@ -1277,14 +1293,14 @@ view.Graph = class extends grapher.Graph {
} }
} }
} }
this._renameMap = new Map(); this._renameMap = new Map();
// clear custom added nodes // clear custom added nodes
this._addedNode = new Map() this._addedNode = new Map()
this.view._graphs[0].reset_custom_added_node() this.view._graphs[0].reset_custom_added_node()
this._addedOutputs = []
this.view._graphs[0].reset_custom_added_outputs()
} }
recordRenameInfo(modelNodeName, src_name, dst_name) { recordRenameInfo(modelNodeName, src_name, dst_name) {

Loading…
Cancel
Save