modify_node_io_name() basically done

1123
ZhangGe6 4 years ago
parent ed093dcc86
commit 323c85fb1f

@ -20,11 +20,12 @@ def return_file():
@app.route('/download', methods=['POST']) @app.route('/download', methods=['POST'])
def modify_and_download_model(): def modify_and_download_model():
node_states = json.loads(request.get_json()) modify_info = request.get_json()
# print(modify_info)
# print(node_states)
onnx_modifier.reload() # allow for downloading for multiple times onnx_modifier.reload() # allow for downloading for multiple times
onnx_modifier.remove_node_by_node_states(node_states) onnx_modifier.remove_node_by_node_states(modify_info['node_states'])
onnx_modifier.modify_node_io_name(modify_info['node_renamed_io'])
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()

@ -10,6 +10,7 @@ class onnxModifier:
def __init__(self, model_name, model_proto): def __init__(self, model_name, model_proto):
self.model_name = model_name self.model_name = model_name
self.model_proto_backup = model_proto self.model_proto_backup = model_proto
self.reload()
@classmethod @classmethod
def from_model_path(cls, model_path): def from_model_path(cls, model_path):
@ -78,12 +79,29 @@ class onnxModifier:
for init_name in self.initilizer_name2module.keys(): for init_name in self.initilizer_name2module.keys():
if not init_name in left_node_inputs: if not init_name in left_node_inputs:
self.initializer.remove(self.initilizer_name2module[init_name]) self.initializer.remove(self.initilizer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io):
# print(node_renamed_io)
for node_name in node_renamed_io.keys():
renamed_ios = node_renamed_io[node_name]
for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name)
node = self.node_name2module[node_name]
# print(node.input, node.output)
for i in range(len(node.input)):
if node.input[i] == src_name:
node.input[i] = dst_name
for i in range(len(node.output)):
if node.output[i] == src_name:
node.output[i] = dst_name
# print(node.input, node.output)
def check_and_save_model(self, save_dir='./res_onnx'): def check_and_save_model(self, save_dir='./res_onnx'):
save_path = os.path.join(save_dir, 'modified_' + self.model_name) save_path = os.path.join(save_dir, 'modified_' + self.model_name)
# onnx.checker.check_model(self.model_proto) onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path) onnx.save(self.model_proto, save_path)
def inference(self): def inference(self):
# model_proto_bytes = onnx._serialize(model_proto_from_stream) # model_proto_bytes = onnx._serialize(model_proto_from_stream)
# inference_session = rt.InferenceSession(model_proto_bytes) # inference_session = rt.InferenceSession(model_proto_bytes)
@ -91,9 +109,9 @@ class onnxModifier:
if __name__ == "__main__": if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def remove_node_by_node_states(): def remove_node_by_node_states():
@ -134,7 +152,7 @@ if __name__ == "__main__":
print(inp.name) print(inp.name)
onnx_modifier.check_and_save_model() onnx_modifier.check_and_save_model()
remove_node_by_node_states() # remove_node_by_node_states()
def explore_basic(): def explore_basic():
print(type(onnx_modifier.model_proto.graph.initializer)) print(type(onnx_modifier.model_proto.graph.initializer))
@ -143,15 +161,20 @@ if __name__ == "__main__":
print(len(onnx_modifier.model_proto.graph.node)) print(len(onnx_modifier.model_proto.graph.node))
print(len(onnx_modifier.model_proto.graph.initializer)) print(len(onnx_modifier.model_proto.graph.initializer))
# for node in onnx_modifier.model_proto.graph.node: for node in onnx_modifier.model_proto.graph.node:
# print(node.name) print(node.name)
# print(node.input) print(node.input)
# print() print()
# for initializer in onnx_modifier.model_proto.graph.initializer: # for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name) # print(initializer.name)
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale']) # print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
pass
# explore_basic() # explore_basic()
def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model()
test_modify_node_io_name()

@ -214,7 +214,8 @@ host.BrowserHost = class {
const downloadButton = this.document.getElementById('download-graph'); const downloadButton = this.document.getElementById('download-graph');
downloadButton.addEventListener('click', () => { downloadButton.addEventListener('click', () => {
// console.log(this._host._view._graph._modelNodeName2State) console.log(this._view._graph._modelNodeName2State)
console.log(this._view._graph._renameMap)
// https://healeycodes.com/talking-between-languages // https://healeycodes.com/talking-between-languages
fetch('/download', { fetch('/download', {
// Declare what type of data we're sending // Declare what type of data we're sending
@ -224,8 +225,13 @@ host.BrowserHost = class {
// Specify the method // Specify the method
method: 'POST', method: 'POST',
// https://blog.csdn.net/Crazy_SunShine/article/details/80624366 // https://blog.csdn.net/Crazy_SunShine/article/details/80624366
body: JSON.stringify( body: JSON.stringify({
this._mapToJson(this._view._graph._modelNodeName2State) // 'node_states' : this._mapToJson(this._view._graph._modelNodeName2State),
// 'node_renamed_io' : this._twoLevelMapToJson(this._view._graph._renameMap),
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
}
) )
}).then(function (response) { }).then(function (response) {
return response.text(); return response.text();
@ -605,16 +611,33 @@ host.BrowserHost = class {
_strMapToObj(strMap){ _strMapToObj(strMap){
let obj = Object.create(null); let obj = Object.create(null);
for (let[k, v] of strMap) { for (let [k, v] of strMap) {
obj[k] = v; obj[k] = v;
} }
return obj; return obj;
} }
// {key1:val1, key2:val2, ...} => Json
_mapToJson(map) { _mapToJson(map) {
return JSON.stringify(this._strMapToObj(map)); return JSON.stringify(this._strMapToObj(map));
} }
// https://www.xul.fr/javascript/map-and-object.php
mapToObjectRec(m) {
let lo = {}
for(let[k,v] of m) {
if(v instanceof Map) {
lo[k] = this.mapToObjectRec(v)
}
else {
lo[k] = v
}
}
return lo
}
}; };
host.Dropdown = class { host.Dropdown = class {

@ -247,7 +247,7 @@ sidebar.NodeSidebar = class {
newNameElement.setAttribute('value', this._host._view._graph._renameMap.get(this._modelNodeName).get(argument.name)); newNameElement.setAttribute('value', this._host._view._graph._renameMap.get(this._modelNodeName).get(argument.name));
} }
newNameElement.addEventListener('input', (e) => { newNameElement.addEventListener('input', (e) => {
// console.log(e.target.value); console.log(e.target.value);
this._host._view._graph.recordRenameInfo(this._modelNodeName, argument.name, e.target.value); this._host._view._graph.recordRenameInfo(this._modelNodeName, argument.name, e.target.value);
// console.log(this._host._view._graph._renameMap); // console.log(this._host._view._graph._renameMap);

@ -958,7 +958,8 @@ view.Graph = class extends grapher.Graph {
// if this argument has been renamed // if this argument has been renamed
if ( if (
this._renameMap.get(viewNode.modelNodeName) && this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name) this._renameMap.get(viewNode.modelNodeName).get(argument.name) &&
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // in case user clear the input name
) )
{ {
// argument.name = this._renameMap.get(viewNode.modelNodeName).get(argument.name); // argument.name = this._renameMap.get(viewNode.modelNodeName).get(argument.name);
@ -992,7 +993,8 @@ view.Graph = class extends grapher.Graph {
// if this argument has been renamed // if this argument has been renamed
if ( if (
this._renameMap.get(viewNode.modelNodeName) && this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name) this._renameMap.get(viewNode.modelNodeName).get(argument.name) &&
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == ''
) )
{ {
// console.log(argument.name) // console.log(argument.name)

Loading…
Cancel
Save