ZhangGe6 4 years ago
parent f45da8d4fb
commit 38649877eb

@ -109,11 +109,12 @@ class onnxModifier:
# self.initializer.remove(self.initializer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io):
# print(node_renamed_io)
for node_name in node_renamed_io.keys():
if node_name not in self.node_name2module.keys():
# custom added nodes or custom added model outputs
continue
renamed_ios = node_renamed_io[node_name]
for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name)
node = self.node_name2module[node_name]
if node_name in self.graph_input_names:
node.name = dst_name
@ -149,16 +150,28 @@ class onnxModifier:
self.graph.node.append(node)
def add_outputs(self, added_outputs, node_states):
# https://github.com/onnx/onnx/issues/3277#issuecomment-1050600445
added_output_names = added_outputs.values()
# filter out the deleted custom-added outputs
value_info_protos = []
shape_info = onnx.shape_inference.infer_shapes(self.model_proto)
for value_info in shape_info.graph.value_info:
if value_info.name in added_output_names:
value_info_protos.append(value_info)
self.graph.output.extend(value_info_protos)
def modify(self, modify_info):
# print(modify_info['node_states'])
# print(modify_info['node_renamed_io'])
# print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
# print(modify_info['added_outputs'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
self.add_outputs(modify_info['added_outputs'], modify_info['node_states'])
def check_and_save_model(self, save_dir='./modified_onnx'):
@ -191,15 +204,15 @@ class onnxModifier:
# This issue may be encountered: https://github.com/microsoft/onnxruntime/issues/7506
out = inference_session.run(None, {input_name: x})[0]
# print(out)
print(out.shape)
if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_modified_squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\with-added-output-modified_modified_squeezenet1.0-12.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" // TODO: this model is not supported well , but why?
# model_path = "C:\\Users\\ZhangGe\\Desktop\\mobilenetv2-7.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
@ -259,7 +272,6 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model()
# remove_node_by_node_states()
def test_modify_node_io_name():
node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
@ -285,34 +297,14 @@ if __name__ == "__main__":
onnx_modifier.check_and_save_model()
# test_change_node_attr()
def debug_remove_node_by_node_states():
# print(len(onnx_modifier.graph.node))
# print(len(onnx_modifier.graph.initializer))
# print(onnx_modifier.node_name2module.keys())
# print(onnx_modifier.graph.node)
# for node in onnx_modifier.graph.node:
# print(node.name)
# print(node.input)
# print(node.output)
# print('\noriginal input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
node_states = {'data_0': 'Exist', 'Conv0': 'Exist', 'Relu1': 'Exist', 'MaxPool2': 'Exist', 'Conv3': 'Exist', 'Relu4': 'Exist', 'Conv5': 'Exist', 'Relu6': 'Exist', 'Conv7': 'Exist', 'Relu8': 'Exist', 'Concat9': 'Exist', 'Conv10': 'Exist', 'Relu11': 'Exist', 'Conv12': 'Exist', 'Relu13': 'Exist', 'Conv14': 'Exist', 'Relu15': 'Exist', 'Concat16': 'Exist', 'MaxPool17': 'Exist', 'Conv18': 'Exist', 'Relu19': 'Exist', 'Conv20': 'Exist', 'Relu21': 'Exist', 'Conv22': 'Exist', 'Relu23': 'Exist', 'Concat24': 'Exist', 'Conv25': 'Exist', 'Relu26': 'Exist', 'Conv27': 'Exist', 'Relu28': 'Exist', 'Conv29': 'Exist', 'Relu30': 'Exist', 'Concat31': 'Exist', 'MaxPool32': 'Exist', 'Conv33': 'Exist', 'Relu34': 'Exist', 'Conv35': 'Exist', 'Relu36': 'Exist', 'Conv37': 'Exist', 'Relu38': 'Exist', 'Concat39': 'Exist', 'Conv40': 'Exist', 'Relu41': 'Exist', 'Conv42': 'Exist', 'Relu43': 'Exist', 'Conv44': 'Exist', 'Relu45': 'Exist', 'Concat46': 'Exist', 'Conv47': 'Exist', 'Relu48': 'Exist', 'Conv49': 'Exist', 'Relu50': 'Exist', 'Conv51': 'Exist', 'Relu52': 'Exist', 'Concat53': 'Exist', 'Conv54': 'Exist', 'Relu55': 'Deleted', 'Conv56': 'Deleted', 'Relu57': 'Deleted', 'Conv58': 'Deleted', 'Relu59': 'Deleted', 'Concat60': 'Deleted', 'Dropout61': 'Deleted', 'Conv62': 'Deleted', 'Relu63': 'Deleted', 'GlobalAveragePool64': 'Deleted', 'Softmax65': 'Deleted', 'out_softmaxout_1': 'Deleted'}
# print('\graph input')
# for inp in onnx_modifier.graph.input:
# print(inp.name)
onnx_modifier.remove_node_by_node_states(node_states)
print('\nleft input')
for inp in onnx_modifier.graph.input:
print(inp.name)
onnx_modifier.check_and_save_model()
debug_remove_node_by_node_states()
def test_inference():
onnx_modifier.inference()
test_inference()
def test_add_output():
# print(onnx_modifier.graph.output)
onnx_modifier.add_outputs(['fire2/squeeze1x1_1'])
# print(onnx_modifier.graph.output)
onnx_modifier.check_and_save_model()
# test_add_output()

@ -229,10 +229,10 @@ host.BrowserHost = class {
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode))
}
)
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)),
'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs,
this._view._graph._renameMap, this._view._graph._modelNodeName2State))
})
}).then(function (response) {
return response.text();
}).then(function (text) {
@ -264,9 +264,6 @@ host.BrowserHost = class {
this._view._updateGraph();
})
this.document.getElementById('version').innerText = this.version;
if (this._meta.file) {
@ -677,6 +674,33 @@ host.BrowserHost = class {
}
return lo
}
// this function does 2 things:
// 1. rename the addedOutputs with their new names using renameMap. Because addedOutputs are stored in lists,
// it may be not easy to rename them while editing. (Of course there may be a better way to do this)
// 2. filter out the custom output which is added, but deleted later
process_added_outputs(addedOutputs, renameMap, modelNodeName2State) {
var processed = []
for (let i = 0; i < addedOutputs.length; ++i) {
if (modelNodeName2State.get("out_" + addedOutputs[i]) == "Exist") {
processed.push(addedOutputs[i]);
}
}
for (let i = 0; i < processed.length; ++i) {
if (renameMap.get("out_" + processed[i])) {
processed[i] = renameMap.get("out_" + processed[i]).get(processed[i]);
}
}
return processed;
}
// https://stackoverflow.com/a/4215753/10096987
arrayToObject(arr) {
var rv = {};
for (var i = 0; i < arr.length; ++i)
if (arr[i] !== undefined) rv[i] = arr[i];
return rv;
}
// convert view.LightNodeInfo to Map object for easier transmission to Python backend
parseLightNodeInfo2Map(nodes_info) {

@ -438,6 +438,7 @@ onnx.Graph = class {
this._custom_add_node_io_idx = 0
this._custom_added_node = []
this._custom_added_outputs = []
// model parameter assignment here!
// console.log(graph)
@ -504,7 +505,8 @@ onnx.Graph = class {
}
get outputs() {
return this._outputs;
// return this._outputs;
return this._outputs.concat(this._custom_added_outputs);
}
get nodes() {
@ -632,6 +634,15 @@ onnx.Graph = class {
return custom_add_node;
}
reset_custom_added_outputs() {
this._custom_added_outputs = [];
}
add_output(name) {
const argument = this._context.argument(name);
this._custom_added_outputs.push(new onnx.Parameter(name, [ argument ]));
}
};
onnx.Parameter = class {

@ -24,6 +24,7 @@ grapher.Graph = class {
this._changedAttributes = new Map();
this._addedNode = new Map();
this._addedOutputs = [];
}
get options() {

@ -209,6 +209,9 @@ sidebar.NodeSidebar = class {
this._addButton('Recover Node');
this.add_separator(this._elements, 'sidebar-view-separator')
this._addButton('Enter');
this._addHeader('Output adding helper');
this._addButton('Add Output');
// deprecated
// this.add_separator(this._elements, 'sidebar-view-separator');
@ -272,8 +275,6 @@ sidebar.NodeSidebar = class {
}
}
}
}
render() {
@ -356,7 +357,11 @@ sidebar.NodeSidebar = class {
this._host._view._updateGraph()
});
}
if (title === 'Add Output') {
buttonElement.addEventListener('click', () => {
this._host._view._graph.add_output(this._modelNodeName)
});
}
}
// deprecated

@ -464,7 +464,6 @@ view.View = class {
this.refreshModelInputOutput()
this.refreshNodeArguments()
this.refreshNodeAttributes()
}
return active_graph
@ -580,7 +579,8 @@ view.View = class {
viewGraph._renameMap = this.lastViewGraph._renameMap;
viewGraph._changedAttributes = this.lastViewGraph._changedAttributes;
viewGraph._addedNode = this.lastViewGraph._addedNode;
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey;
viewGraph._addedOutputs = this.lastViewGraph._addedOutputs;
// console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State)
}
@ -866,7 +866,7 @@ view.View = class {
}
// re-generate the added node according to _addedNode
// re-generate the added node according to _addedNode according to the latest _addedNode
refreshAddedNode() {
this._graphs[0].reset_custom_added_node()
// for (const node_info of this._addedNode.values()) {
@ -880,8 +880,7 @@ view.View = class {
for (const arg of input._arguments) {
input_list_names.push(arg.name)
}
this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names)
this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names)
}
for (const output of node.outputs) {
@ -896,7 +895,7 @@ view.View = class {
}
// re-fresh node arguments in case the node inputs/outputs are changed
refreshNodeArguments() {
refreshNodeArguments() {
for (var node of this._graphs[0]._nodes) {
if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
@ -979,7 +978,13 @@ view.View = class {
}
}
for (var output of this._graphs[0]._outputs) {
// create and add new output to graph
this._graphs[0].reset_custom_added_outputs();
for (var output_name of this.lastViewGraph._addedOutputs) {
this._graphs[0].add_output(output_name);
}
// console.log(this._graphs[0].outputs)
for (var output of this._graphs[0].outputs) {
var output_orig_name = output.arguments[0].original_name
if (this.lastViewGraph._renameMap.get('out_' + output_orig_name)) {
// for model input and output, node.modelNodeName == element.original_name
@ -1011,6 +1016,7 @@ view.View = class {
}
}
}
// console.log(this.lastViewGraph._renameMap)
}
}
}
@ -1190,7 +1196,6 @@ view.Graph = class extends grapher.Graph {
}
}
for (const output of graph.outputs) {
const viewOutput = this.createOutput(output);
for (const argument of output.arguments) {
@ -1238,6 +1243,17 @@ view.Graph = class extends grapher.Graph {
}
}
add_output(node_name) {
var model_node = this._modelNodeName2ModelNode.get(node_name);
for (var output of model_node.outputs) {
for (var argument of output.arguments) {
this._addedOutputs.push(argument.name);
}
}
// console.log(this._addedOutputs);
this.view._updateGraph();
}
resetGraph() {
// reset node states
for (const nodeId of this.nodes.keys()) {
@ -1277,14 +1293,14 @@ view.Graph = class extends grapher.Graph {
}
}
}
this._renameMap = new Map();
// clear custom added nodes
this._addedNode = new Map()
this.view._graphs[0].reset_custom_added_node()
this._addedOutputs = []
this.view._graphs[0].reset_custom_added_outputs()
}
recordRenameInfo(modelNodeName, src_name, dst_name) {

Loading…
Cancel
Save