changing the name of model input is done

1123
ZhangGe6 4 years ago
parent 43a693dd9d
commit b34157e345

@ -22,7 +22,7 @@ def open_model():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
# print(modify_info)
print(modify_info)
onnx_modifier.reload() # allow downloading for multiple times

@ -45,6 +45,10 @@ class onnxModifier:
node_idx += 1
self.node_name2module[node.name] = node
for inp in self.graph.input:
self.node_name2module[inp.name] = inp
self.graph_input_names = [inp.name for inp in self.graph.input]
for out in self.graph.output:
self.node_name2module["out_" + out.name] = out # add `out_` in case the output has the same name with the last node
self.graph_output_names = ["out_" + out.name for out in self.graph.output]
@ -76,7 +80,7 @@ class onnxModifier:
# print('removing node {} ...'.format(node_name))
self.remove_node_by_name(node_name)
# remove node initializers (parameters) aka, keep and only keep the initializers of left nodes
# remove node initializers (parameters), aka, keep and only keep the initializers of left nodes
left_node_inputs = []
for left_node in self.graph.node:
left_node_inputs += left_node.input
@ -106,6 +110,12 @@ class onnxModifier:
for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name)
node = self.node_name2module[node_name]
if node_name in self.graph_input_names:
node.name = dst_name
# print(node.name)
# print(node)
pass
else:
# print(node.input, node.output)
for i in range(len(node.input)):
if node.input[i] == src_name:
@ -257,10 +267,10 @@ if __name__ == "__main__":
def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
node_rename_io = {'input': {'input': 'inputd'}, 'Conv_0': {'input': 'inputd'}}
onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model()
# test_modify_node_io_name()
test_modify_node_io_name()
def test_add_node():
node_info = {'custom_added_AveragePool0': {'properties': {'domain': 'ai.onnx', 'op_type': 'AveragePool', 'name': 'custom_added_AveragePool0'}, 'attributes': {'kernel_shape': [2, 2]}, 'inputs': {'X': ['fire2/squeeze1x1_1']}, 'outputs': {'Y': ['out']}}}
@ -278,7 +288,7 @@ if __name__ == "__main__":
onnx_modifier.modify_node_attr(changed_attr)
onnx_modifier.check_and_save_model()
test_change_node_attr()
# test_change_node_attr()

@ -801,7 +801,7 @@ class NodeAttributeView {
sidebar.ParameterView = class {
constructor(host, list, inp_or_out, param_idx, modelNodeName) {
constructor(host, list, param_type, param_idx, modelNodeName) {
this._host = host;
this._list = list;
this._modelNodeName = modelNodeName
@ -811,7 +811,7 @@ sidebar.ParameterView = class {
// console.log(list)
// for (const argument of list.arguments) {
for (const [arg_idx, argument] of list.arguments.entries()) {
const item = new sidebar.ArgumentView(host, argument, inp_or_out, param_idx, arg_idx, list._name, this._modelNodeName);
const item = new sidebar.ArgumentView(host, argument, param_type, param_idx, arg_idx, list._name, this._modelNodeName);
item.on('export-tensor', (sender, tensor) => {
this._raise('export-tensor', tensor);
});
@ -850,10 +850,10 @@ sidebar.ParameterView = class {
sidebar.ArgumentView = class {
constructor(host, argument, inp_or_out, param_index, arg_index, parameterName, modelNodeName) {
constructor(host, argument, param_type, param_index, arg_index, parameterName, modelNodeName) {
this._host = host;
this._argument = argument;
this._inp_or_out = inp_or_out
this._param_type = param_type
this._param_index = param_index
this._arg_index = arg_index
this._parameterName = parameterName
@ -907,7 +907,7 @@ sidebar.ArgumentView = class {
// console.log(e.target.value);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value, this._argument._name);
// this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._arg_index, e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._inp_or_out, this._param_index, this._arg_index, e.target.value);
this._host._view._graph.changeNodeInputOutput(this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, e.target.value);
// console.log(this._host._view._graph._renameMap);
});
this._element.appendChild(arg_input);
@ -1117,8 +1117,9 @@ sidebar.ModelSidebar = class {
}
if (Array.isArray(graph.inputs) && graph.inputs.length > 0) {
this._addHeader('Inputs');
for (const input of graph.inputs) {
this.addArgument(input.name, input);
// for (const input of graph.inputs) {
for (const [index, input] of graph.inputs.entries()){
this.addArgument(input.name, input, index);
}
}
if (Array.isArray(graph.outputs) && graph.outputs.length > 0) {
@ -1150,8 +1151,9 @@ sidebar.ModelSidebar = class {
this._elements.push(item.render());
}
addArgument(name, argument) {
const view = new sidebar.ParameterView(this._host, argument);
addArgument(name, argument, index) {
// const view = new sidebar.ParameterView(this._host, argument);
const view = new sidebar.ParameterView(this._host, argument, 'model_input', index, name);
view.toggle();
const item = new sidebar.NameValueView(this._host, name, view);
this._elements.push(item.render());

@ -13,9 +13,6 @@ var python = python || require('./python');
var sidebar = sidebar || require('./view-sidebar');
var grapher = grapher || require('./view-grapher');
// var onnx = onnx || require('./onnx');
view.View = class {
constructor(host, id) {
@ -464,7 +461,9 @@ view.View = class {
var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
if (active_graph && this.lastViewGraph) {
this.refreshAddedNode()
this.refreshModelInputOutput()
this.refreshNodeArguments()
this.refreshNodeAttributes()
}
@ -910,14 +909,14 @@ view.View = class {
if (this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)) {
// console.log(element.name)
var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
console.log(element.original_name)
console.log(new_name)
// console.log(element.original_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
input.arguments[index] = arg_with_new_name
console.log(arg_with_new_name)
console.log(node)
// console.log(arg_with_new_name)
// console.log(node)
}
}
}
@ -942,6 +941,9 @@ view.View = class {
}
}
}
refreshNodeAttributes() {
for (const node_name of this.lastViewGraph._changedAttributes.keys()) {
var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name)
var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
@ -953,7 +955,62 @@ view.View = class {
}
}
}
refreshModelInputOutput() {
// console.log("refreshModelInputOutput", this._graphs[0])
// console.log(this.lastViewGraph._renameMap)
for (var input of this._graphs[0]._inputs) {
// console.log(input)
// console.log(input.modelNodeName)
if (this.lastViewGraph._renameMap.get(input.modelNodeName)) {
// for model input and output, node.modelNodeName == element.original_name
var new_name = this.lastViewGraph._renameMap.get(input.modelNodeName).get(input.modelNodeName)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, input.modelNodeName)
input.arguments[0] = arg_with_new_name
// change all the name of node input linked with model input meanwhile
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
// console.log(node.modelNodeName)
// if (this.lastViewGraph._renameMap.get(node.modelNodeName)) {
for (var node_input of node.inputs) {
for (const [index, element] of node_input.arguments.entries()) {
// console.log(element.name, input.modelNodeName)
// if (element.name == input.modelNodeName) {
if (element.original_name == input.modelNodeName) {
// console.log(element.name)
// var new_name = this.lastViewGraph._renameMap.get(node.modelNodeName).get(element.original_name)
// console.log(element.original_name)
// console.log(new_name)
var arg_with_new_name = this._graphs[0]._context.argument(new_name, element.original_name)
node_input.arguments[index] = arg_with_new_name
// save the changed name into _renameMap
// as this modified _renamedMap, so refreshModelInputOutput() shoulf be called before refreshNodeArguments()
if (!this.lastViewGraph._renameMap.get(node.modelNodeName)) {
this.lastViewGraph._renameMap.set(node.modelNodeName, new Map());
}
var orig_arg_name = element.original_name
this.lastViewGraph._renameMap.get(node.modelNodeName).set(orig_arg_name, new_name);
// console.log(arg_with_new_name)
// console.log(node)
}
}
}
// }
}
}
}
}
};
@ -986,6 +1043,7 @@ view.Graph = class extends grapher.Graph {
const value = new view.Input(this, input);
// value.name = (this._nodeKey++).toString();
value.name = input.name; // input nodes should have name
input.modelNodeName = input.name;
this.setNode(value);
return value;
}
@ -995,6 +1053,7 @@ view.Graph = class extends grapher.Graph {
const value = new view.Output(this, output, modelNodeName);
// value.name = (this._nodeKey++).toString();
value.name = "out_" + output.name; // output nodes should have name
output.modelNodeName = "out_" + output.name;
this.setNode(value);
return value;
}
@ -1175,6 +1234,13 @@ view.Graph = class extends grapher.Graph {
for (const changed_node_name of this._renameMap.keys()) {
var node = this._modelNodeName2ModelNode.get(changed_node_name)
console.log(node)
// console.log(typeof node)
// console.log(node.constructor.name)
if (node.arguments) { // model input or model output. Because they are purely onnx.Parameter
node.arguments[0] = this.view._graphs[0]._context.argument(node.modelNodeName)
}
else { // model nodes
//reset inputs
for (var input of node.inputs) {
for (var i = 0; i < input.arguments.length; ++i) {
@ -1193,6 +1259,10 @@ view.Graph = class extends grapher.Graph {
}
}
}
}
}
this._renameMap = new Map();
@ -1241,7 +1311,7 @@ view.Graph = class extends grapher.Graph {
}
changeNodeInputOutput(modelNodeName, parameterName, inp_or_out, param_index, arg_index, targetValue, orig_arg_name) {
changeNodeInputOutput(modelNodeName, parameterName, param_type, param_index, arg_index, targetValue, orig_arg_name) {
// changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) {
if (this._addedNode.has(modelNodeName)) { // for custom added node
if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) {
@ -1260,11 +1330,22 @@ view.Graph = class extends grapher.Graph {
this._renameMap.set(modelNodeName, new Map());
}
if (inp_or_out == 'input') {
if (param_type == 'model_input' || param_type == 'model_output') {
// var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).arguments[arg_index].orig_arg_name
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments)
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0])
// console.log(this._modelNodeName2ModelNode.get(modelNodeName).arguments[0].orig_arg_name)
// console.log("changing model_input", orig_arg_name)
var orig_arg_name = modelNodeName
// console.log("changing model_input", orig_arg_name)
}
if (param_type == 'input') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).inputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}
if (inp_or_out == 'output') {
if (param_type == 'output') {
var orig_arg_name = this._modelNodeName2ModelNode.get(modelNodeName).outputs[param_index].arguments[arg_index].original_name
// console.log(orig_arg_name)
}

Loading…
Cancel
Save