modify_node_io_name() basically done

1123
ZhangGe6 4 years ago
parent ed093dcc86
commit 323c85fb1f

@ -20,11 +20,12 @@ def return_file():
@app.route('/download', methods=['POST'])
def modify_and_download_model():
node_states = json.loads(request.get_json())
modify_info = request.get_json()
# print(modify_info)
# print(node_states)
onnx_modifier.reload() # allow for downloading for multiple times
onnx_modifier.remove_node_by_node_states(node_states)
onnx_modifier.remove_node_by_node_states(modify_info['node_states'])
onnx_modifier.modify_node_io_name(modify_info['node_renamed_io'])
onnx_modifier.check_and_save_model()

@ -10,6 +10,7 @@ class onnxModifier:
def __init__(self, model_name, model_proto):
self.model_name = model_name
self.model_proto_backup = model_proto
self.reload()
@classmethod
def from_model_path(cls, model_path):
@ -78,12 +79,29 @@ class onnxModifier:
for init_name in self.initilizer_name2module.keys():
if not init_name in left_node_inputs:
self.initializer.remove(self.initilizer_name2module[init_name])
def modify_node_io_name(self, node_renamed_io):
# print(node_renamed_io)
for node_name in node_renamed_io.keys():
renamed_ios = node_renamed_io[node_name]
for src_name, dst_name in renamed_ios.items():
# print(src_name, dst_name)
node = self.node_name2module[node_name]
# print(node.input, node.output)
for i in range(len(node.input)):
if node.input[i] == src_name:
node.input[i] = dst_name
for i in range(len(node.output)):
if node.output[i] == src_name:
node.output[i] = dst_name
# print(node.input, node.output)
def check_and_save_model(self, save_dir='./res_onnx'):
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
# onnx.checker.check_model(self.model_proto)
onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
def inference(self):
# model_proto_bytes = onnx._serialize(model_proto_from_stream)
# inference_session = rt.InferenceSession(model_proto_bytes)
@ -91,9 +109,9 @@ class onnxModifier:
if __name__ == "__main__":
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-3.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\tflite_sim.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def remove_node_by_node_states():
@ -134,7 +152,7 @@ if __name__ == "__main__":
print(inp.name)
onnx_modifier.check_and_save_model()
remove_node_by_node_states()
# remove_node_by_node_states()
def explore_basic():
print(type(onnx_modifier.model_proto.graph.initializer))
@ -143,15 +161,20 @@ if __name__ == "__main__":
print(len(onnx_modifier.model_proto.graph.node))
print(len(onnx_modifier.model_proto.graph.initializer))
# for node in onnx_modifier.model_proto.graph.node:
# print(node.name)
# print(node.input)
# print()
for node in onnx_modifier.model_proto.graph.node:
print(node.name)
print(node.input)
print()
# for initializer in onnx_modifier.model_proto.graph.initializer:
# print(initializer.name)
# print(onnx_modifier.model_proto.graph.initializer['fire9/concat_1_scale'])
pass
# explore_basic()
def test_modify_node_io_name():
node_rename_io = {'Conv3': {'pool1_1': 'conv1_1'}}
onnx_modifier.modify_node_io_name(node_rename_io)
onnx_modifier.check_and_save_model()
test_modify_node_io_name()

@ -214,7 +214,8 @@ host.BrowserHost = class {
const downloadButton = this.document.getElementById('download-graph');
downloadButton.addEventListener('click', () => {
// console.log(this._host._view._graph._modelNodeName2State)
console.log(this._view._graph._modelNodeName2State)
console.log(this._view._graph._renameMap)
// https://healeycodes.com/talking-between-languages
fetch('/download', {
// Declare what type of data we're sending
@ -224,8 +225,13 @@ host.BrowserHost = class {
// Specify the method
method: 'POST',
// https://blog.csdn.net/Crazy_SunShine/article/details/80624366
body: JSON.stringify(
this._mapToJson(this._view._graph._modelNodeName2State)
body: JSON.stringify({
// 'node_states' : this._mapToJson(this._view._graph._modelNodeName2State),
// 'node_renamed_io' : this._twoLevelMapToJson(this._view._graph._renameMap),
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
}
)
}).then(function (response) {
return response.text();
@ -605,16 +611,33 @@ host.BrowserHost = class {
_strMapToObj(strMap){
let obj = Object.create(null);
for (let[k, v] of strMap) {
for (let [k, v] of strMap) {
obj[k] = v;
}
return obj;
}
// {key1:val1, key2:val2, ...} => Json
_mapToJson(map) {
return JSON.stringify(this._strMapToObj(map));
}
// https://www.xul.fr/javascript/map-and-object.php
mapToObjectRec(m) {
let lo = {}
for(let[k,v] of m) {
if(v instanceof Map) {
lo[k] = this.mapToObjectRec(v)
}
else {
lo[k] = v
}
}
return lo
}
};
host.Dropdown = class {

@ -247,7 +247,7 @@ sidebar.NodeSidebar = class {
newNameElement.setAttribute('value', this._host._view._graph._renameMap.get(this._modelNodeName).get(argument.name));
}
newNameElement.addEventListener('input', (e) => {
// console.log(e.target.value);
console.log(e.target.value);
this._host._view._graph.recordRenameInfo(this._modelNodeName, argument.name, e.target.value);
// console.log(this._host._view._graph._renameMap);

@ -958,7 +958,8 @@ view.Graph = class extends grapher.Graph {
// if this argument has been renamed
if (
this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name)
this._renameMap.get(viewNode.modelNodeName).get(argument.name) &&
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == '' // in case user clear the input name
)
{
// argument.name = this._renameMap.get(viewNode.modelNodeName).get(argument.name);
@ -992,7 +993,8 @@ view.Graph = class extends grapher.Graph {
// if this argument has been renamed
if (
this._renameMap.get(viewNode.modelNodeName) &&
this._renameMap.get(viewNode.modelNodeName).get(argument.name)
this._renameMap.get(viewNode.modelNodeName).get(argument.name) &&
!this._renameMap.get(viewNode.modelNodeName).get(argument.name) == ''
)
{
// console.log(argument.name)

Loading…
Cancel
Save