changing the name of model input is done

1123
ZhangGe6 4 years ago
parent 43a693dd9d
commit b34157e345

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST']) @app.route('/download', methods=['POST'])
def modify_and_download_model(): def modify_and_download_model():
modify_info = request.get_json() modify_info = request.get_json()
# print(modify_info) print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times onnx_modifier.reload() # allow downloading for multiple times

@ -45,6 +45,10 @@ class onnxModifier:
node_idx += 1 node_idx += 1
self.node_name2module[node.name] = node self.node_name2module[node.name] = node
for inp in self.graph.input:
self.node_name2module[inp.name] = inp
self.graph_input_names = [inp.name for inp in self.graph.input]
for out in self.graph.output: for out in self.graph.output:
self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node
self.graph_output_names = ["out_" + out.name for out in self.graph.output] self.graph_output_names = ["out_" + out.name for out in self.graph.output]
@ -76,7 +80,7 @@ class onnxModifier:
# print('removing node {} ...'.format(node_name)) # print('removing node {} ...'.format(node_name))
self.remove_node_by_name(node_name) self.remove_node_by_name(node_name)
# remove node initializers (parameters) aka, keep and only keep the initializers of left nodes # remove node initializers (parameters), aka, keep and only keep the initializers of left nodes
left_node_inputs = [] left_node_inputs = []
for left_node in self.graph.node: for left_node in self.graph.node:
left_node_inputs += left_node.input left_node_inputs += left_node.input
@ -106,6 +110,12 @@ class onnxModifier:
for src_name, dst_name in renamed_ios.items(): for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name) # print(src_name, dst_name)
node = self.node_name2module[node_name] node = self.node_name2module[node_name]
if node_name in self.graph_input_names:
node.name = dst_name
# print(node.name)
# print(node)
pass
else:
# print(node.input, node.output) # print(node.input, node.output)
for i in range(len(node.input)): for i in range(len(node.input)):
if node.input[i] == src_name: if node.input[i] == src_name:
@ -257,10 +267,10 @@ if __name__ == "__main__":
def test_modify_node_io_name(): def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}} node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
onnx_modifier.modify_node_io_name(node_rename_io) onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
# test_modify_node_io_name() test_modify_node_io_name()
def test_add_node(): def test_add_node():
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}} node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
@ -278,7 +288,7 @@ if __name__ == "__main__":
onnx_modifier.modify_node_attr(changed_attr) onnx_modifier.modify_node_attr(changed_attr)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
test_change_node_attr() # test_change_node_attr()

@ -801,7 +801,7 @@ class NodeAttributeView {
sidebar.ParameterView = class { sidebar.ParameterView = class {
constructor(host, list, inp_or_out, param_idx, modelNodeName) { constructor(host, list, param_type, param_idx, modelNodeName) {
this._host = host; this._host = host;
this._list = list; this._list = list;
this._modelNodeName = modelNodeName this._modelNodeName = modelNodeName
@ -811,7 +811,7 @@ sidebar.ParameterView = class {
// console.log(list) // console.log(list)
// for (const argument of list.arguments) { // for (const argument of list.arguments) {
for (const [arg_idx, argument] of list.arguments.entries()) { for (const [arg_idx, argument] of list.arguments.entries()) {
const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName); const item = new sidebar.ArgumentView(host, argument, param_type, param_idx, arg_idx, list._name, this._modelNodeName);
item.on('export-tensor', (sender, tensor) => { item.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor); this._raise('export-tensor', tensor);
}); });
@ -850,10 +850,10 @@ sidebar.ParameterView = class {
sidebar.ArgumentView = class { sidebar.ArgumentView = class {
constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) { constructor(host, argument, param_type, param_index, arg_index, parameterName, modelNodeName) {
this._host = host; this._host = host;
this._argument = argument; this._argument = argument;
this._inp_or_out = inp_or_out this._param_type = param_type
this._param_index = param_index this._param_index = param_index
this._arg_index = arg_index this._arg_index = arg_index
this._parameterName = parameterName this._parameterName = parameterName
@ -907,7 +907,7 @@ sidebar.ArgumentView = class {
// console.log(e.target.value); // console.log(e.target.value);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name); // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value); // this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value); this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, e.target.value);
// console.log(this._host._view._graph._renameMap); // console.log(this._host._view._graph._renameMap);
}); });
this._element.appendChild(arg_input); this._element.appendChild(arg_input);
@ -1117,8 +1117,9 @@ sidebar.ModelSidebar = class {
} }
if (Array.isArray(graph.inputs) && graph.inputs.length > 0) { if (Array.isArray(graph.inputs) && graph.inputs.length > 0) {
this._addHeader('Inputs'); this._addHeader('Inputs');
for (const input of graph.inputs) { // for (const input of graph.inputs) {
this.addArgument(input.name, input); for (const [index, input] of graph.inputs.entries()){
this.addArgument(input.name, input, index);
} }
} }
if (Array.isArray(graph.outputs) && graph.outputs.length > 0) { if (Array.isArray(graph.outputs) && graph.outputs.length > 0) {
@ -1150,8 +1151,9 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render()); this._elements.push(item.render());
} }
addArgument(name, argument) { addArgument(name, argument, index) {
const view = new sidebar.ParameterView(this._host, argument); // const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name);
view.toggle(); view.toggle();
const item = new sidebar.NameValueView(this._host, name, view); const item = new sidebar.NameValueView(this._host, name, view);
this._elements.push(item.render()); this._elements.push(item.render());

@ -13,9 +13,6 @@ var python = python || require('./python');
var sidebar = sidebar || require('./view-sidebar'); var sidebar = sidebar || require('./view-sidebar');
var grapher = grapher || require('./view-grapher'); var grapher = grapher || require('./view-grapher');
// var onnx = onnx || require('./onnx');
view.View = class { view.View = class {
constructor(host, id) { constructor(host, id) {
@ -464,7 +461,9 @@ view.View = class {
var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
if (active_graph && this.lastViewGraph) { if (active_graph && this.lastViewGraph) {
this.refreshAddedNode() this.refreshAddedNode()
this.refreshModelInputOutput()
this.refreshNodeArguments() this.refreshNodeArguments()
this.refreshNodeAttributes()
} }
@ -910,14 +909,14 @@ view.View = class {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) { if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name) // console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name) var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(element.original_name) // console.log(element.original_name)
console.log(new_name) // console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name) var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
input.arguments[index] = arg_with_new_name input.arguments[index] = arg_with_new_name
console.log(arg_with_new_name) // console.log(arg_with_new_name)
console.log(node) // console.log(node)
} }
} }
} }
@ -942,6 +941,9 @@ view.View = class {
} }
} }
}
refreshNodeAttributes() {
for (const node_name of this.lastViewGraph._changedAttributes.keys()) { for (const node_name of this.lastViewGraph._changedAttributes.keys()) {
var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name) var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name)
var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name) var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
@ -953,7 +955,62 @@ view.View = class {
} }
} }
}
refreshModelInputOutput() {
// console.log("refreshModelInputOutput", this._graphs[0])
// console.log(this.lastViewGraph._renameMap)
for (var input of this._graphs[0]._inputs) {
// console.log(input)
// console.log(input.modelNodeName)
if (this.lastViewGraph._renameMap.get(input.modelNodeName)) {
// for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName)
input.arguments[0] = arg_with_new_name
// change all the name of node input linked with model input meanwhile
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
// console.log(node.modelNodeName)
// if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
for (var node_input of node.inputs) {
for (const [index, element] of node_input.arguments.entries()) {
// console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) {
if (element.original_name == input.modelNodeName) {
// console.log(element.name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
node_input.arguments[index] = arg_with_new_name
// save the changed name into _renameMap
// as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments()
if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) {
this.lastViewGraph._renameMap.set(node.modelNodeName, new Map());
}
var orig_arg_name = element.original_name
this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name);
// console.log(arg_with_new_name)
// console.log(node)
}
}
}
// }
}
}
}
} }
}; };
@ -986,6 +1043,7 @@ view.Graph = class extends grapher.Graph {
const value = new view.Input(this, input); const value = new view.Input(this, input);
// value.name = (this._nodeKey++).toString(); // value.name = (this._nodeKey++).toString();
value.name = input.name; // input nodes should have name value.name = input.name; // input nodes should have name
input.modelNodeName = input.name;
this.setNode(value); this.setNode(value);
return value; return value;
} }
@ -995,6 +1053,7 @@ view.Graph = class extends grapher.Graph {
const value = new view.Output(this, output, modelNodeName); const value = new view.Output(this, output, modelNodeName);
// value.name = (this._nodeKey++).toString(); // value.name = (this._nodeKey++).toString();
value.name = "out_" + output.name; // output nodes should have name value.name = "out_" + output.name; // output nodes should have name
output.modelNodeName = "out_" + output.name;
this.setNode(value); this.setNode(value);
return value; return value;
} }
@ -1175,6 +1234,13 @@ view.Graph = class extends grapher.Graph {
for (const changed_node_name of this._renameMap.keys()) { for (const changed_node_name of this._renameMap.keys()) {
var node = this._modelNodeName2ModelNode.get(changed_node_name) var node = this._modelNodeName2ModelNode.get(changed_node_name)
console.log(node) console.log(node)
// console.log(typeof node)
// console.log(node.constructor.name)
if (node.arguments) { // model input or model output. Because they are purely onnx.Parameter
node.arguments[0] = this.view._graphs[0]._context.argument(node.modelNodeName)
}
else { // model nodes
//reset inputs //reset inputs
for (var input of node.inputs) { for (var input of node.inputs) {
for (var i = 0; i < input.arguments.length; ++i) { for (var i = 0; i < input.arguments.length; ++i) {
@ -1193,6 +1259,10 @@ view.Graph = class extends grapher.Graph {
} }
} }
} }
}
} }
this._renameMap = new Map(); this._renameMap = new Map();
@ -1241,7 +1311,7 @@ view.Graph = class extends grapher.Graph {
} }
changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) { changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue, orig_arg_name) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
@ -1260,11 +1330,22 @@ view.Graph = class extends grapher.Graph {
this._renameMap.set(modelNodeName, new Map()); this._renameMap.set(modelNodeName, new Map());
} }
if (inp_or_out == 'input') { if (param_type == 'model_input' || param_type == 'model_output') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments)
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0])
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name)
// console.log("changing model_input", orig_arg_name)
var orig_arg_name = modelNodeName
// console.log("changing model_input", orig_arg_name)
}
if (param_type == 'input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name) // console.log(orig_arg_name)
} }
if (inp_or_out == 'output') { if (param_type == 'output') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name) // console.log(orig_arg_name)
} }

Loading…
Cancel
Save