diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b4dd87e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+gym/
+
+*.onnx
\ No newline at end of file
diff --git a/app.py b/app.py
index da341ba..bfa6f5e 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,6 @@
 from flask import Flask, render_template, request
-import onnx
-import onnxruntime as rt
+import json
+from onnx_modifier import onnxModifier
 
 app = Flask(__name__)
 @app.route('/')
@@ -11,27 +11,20 @@ def index():
 def return_file():
     # https://blog.miguelgrinberg.com/post/handling-file-uploads-with-flask
     onnx_file = request.files['file']
-    onnx_file_name = onnx_file.filename
-    onnx_file_stream = onnx_file.stream
-
-    # onnx_file_stream.seek(0)
-    # model_proto_from_stream = onnx.load_model(onnx_file_stream, onnx.ModelProto)
-    # print(model_proto_from_stream.graph)
-    # model_proto_bytes = onnx._serialize(model_proto_from_stream)
-    # inference_session = rt.InferenceSession(model_proto_bytes)
-    # onnx.save_model(model_proto_from_binary_stream, onnx_file.filename)
+    global onnx_modifier
+    onnx_modifier = onnxModifier.from_name_stream(onnx_file.filename, onnx_file.stream)
 
     return 'OK', 200
 
 @app.route('/download', methods=['POST'])
 def modify_and_download_model():
-    # modelNodeStates = request.get_json()
-    # print(modelNodeStates)
-
-
+    node_states = json.loads(request.get_json())
+    # print(modelNodeStates)
+    onnx_modifier.remove_node_by_node_states(node_states)
+    onnx_modifier.check_and_save_model()
 
     return 'OK', 200
diff --git a/onnx_modifier.py b/onnx_modifier.py
index 2df4e44..d55b93c 100644
--- a/onnx_modifier.py
+++ b/onnx_modifier.py
@@ -1,7 +1,8 @@
 # https://leimao.github.io/blog/ONNX-Python-API/
 # https://github.com/saurabh-shandilya/onnx-utils
-
+import io
 import os
+from platform import node
 import onnx
 
 class onnxModifier:
@@ -15,15 +16,22 @@ class onnxModifier:
     def gen_node_name2module_map(self):
         self.node_name2module = dict()
         node_idx = 0
+        for node in self.graph.input:
+            node_idx += 1
+            self.node_name2module[node.name] = node
+
         for node in self.graph.node:
             if node.name == '':
                 node.name = str(node.op_type) + str(node_idx)
             node_idx += 1
-
             self.node_name2module[node.name] = node
-
-        # print(self.node_name2module.keys())
-
+        for node in self.graph.output:
+            node_idx += 1
+            self.node_name2module[node.name] = node
+        self.graph_output_names = [node.name for node in self.graph.output]
+
+        # print(self.node_name2module.keys())
+
     @classmethod
     def from_model_path(cls, model_path):
         model_name = os.path.basename(model_path)
@@ -33,26 +41,54 @@ class onnxModifier:
     @classmethod
     def from_name_stream(cls, name, stream):
         # https://leimao.github.io/blog/ONNX-IO-Stream/
+        stream.seek(0)
         model_proto = onnx.load_model(stream, onnx.ModelProto)
         return cls(name, model_proto)
 
     def remove_node_by_name(self, node_name):
         self.graph.node.remove(self.node_name2module[node_name])
+
+    def remove_output_by_name(self, node_name):
+        self.graph.output.remove(self.node_name2module[node_name])
+
+    def remove_node_by_node_states(self, node_states):
+        for node_name, node_state in node_states.items():
+            if node_state == 'Deleted':
+                if node_name in self.graph_output_names:
+                    self.remove_output_by_name(node_name)
+                else:
+                    self.remove_node_by_name(node_name)
 
-    def check_and_save_model(self, save_dir='./'):
+    def check_and_save_model(self, save_dir='./res_onnx'):
         save_path = os.path.join(save_dir, 'modified_' + self.model_name)
         onnx.checker.check_model(self.model_proto)
         onnx.save(self.model_proto, save_path)
+
+    def inference(self):
+        # model_proto_bytes = onnx._serialize(model_proto_from_stream)
+        # inference_session = rt.InferenceSession(model_proto_bytes)
+        pass
 
 
 if __name__ == "__main__":
     model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
-    onnx_modifer = onnxModifier.from_model_path(model_path)
-    onnx_modifer.remove_node_by_name('Softmax_nc_rename_64')
-    onnx_modifer.check_and_save_model()
+    onnx_modifier = onnxModifier.from_model_path(model_path)
+    onnx_modifier.remove_node_by_name('Softmax_nc_rename_64')
+    onnx_modifier.remove_output_by_name('softmaxout_1')
+    # onnx_modifier.graph.output.remove(onnx_modifier.node_name2module['softmaxout_1'])
+    onnx_modifier.check_and_save_model()
+
+    # print(onnx_modifier.graph.input)
+    # print(onnx_modifier.graph.output)
+    # print(onnx_modifier.node_name2module['Softmax_nc_rename_64'])
+    # print(onnx_modifier.node_name2module['softmaxout_1'])
+    # onnx_modifier.remove_node_by_name('softmaxout_1')
+    # for node in onnx_modifier.graph.output:
+    #     print(node.name)
+
diff --git a/static/view-sidebar.js b/static/view-sidebar.js
index de6ec35..84cfe06 100644
--- a/static/view-sidebar.js
+++ b/static/view-sidebar.js
@@ -322,10 +322,9 @@ sidebar.NodeSidebar = class {
                 // body: this._mapToJson(
                 //     this._host._view._graph._modelNodeName2State
                 // )
-                body: JSON.stringify({
-                    "onnx_file_path" : null,
-                    "node_states": this._mapToJson(this._host._view._graph._modelNodeName2State),
-                })
+                body: JSON.stringify(
+                    this._mapToJson(this._host._view._graph._modelNodeName2State)
+                )
             }).then(function (response) {
                 return response.text();
             }).then(function (text) {
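
Below is a brief usage sketch (not part of the patch) of how the onnxModifier API changed above fits together. The model path is a hypothetical placeholder, the node and output names are taken from the __main__ example in the diff, and it assumes the ./res_onnx output directory already exists, since check_and_save_model does not create it.

from onnx_modifier import onnxModifier

# Load a model from disk; the Flask route uses from_name_stream on the uploaded file instead.
modifier = onnxModifier.from_model_path('squeezenet1.0-12-int8.onnx')  # hypothetical path

# Mirrors the JSON posted by the sidebar: names map to a state string,
# and only entries marked 'Deleted' are acted on by remove_node_by_node_states.
node_states = {
    'Softmax_nc_rename_64': 'Deleted',  # a graph node -> handled by remove_node_by_name
    'softmaxout_1': 'Deleted',          # a graph output -> handled by remove_output_by_name
}
modifier.remove_node_by_node_states(node_states)

# Runs onnx.checker.check_model and writes 'modified_<model_name>' under ./res_onnx.
modifier.check_and_save_model()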