From ebc280dc92b19a186f76d08fbcfce0a0309799b5 Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Mon, 25 Apr 2022 21:34:10 +0800 Subject: [PATCH] local onnxModifier done --- app.py | 37 ++++++++++++++++-------------- onnx_modifier.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 onnx_modifier.py diff --git a/app.py b/app.py index 2b9cb82..da341ba 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,3 @@ -import json from flask import Flask, render_template, request import onnx import onnxruntime as rt @@ -8,29 +7,33 @@ app = Flask(__name__) def index(): return render_template('index.html') -@app.route('/download', methods=['POST']) -def downloadModel(): - modelNodeStates = request.get_json() - print(modelNodeStates) - return 'OK', 200 - @app.route('/return_file', methods=['POST']) def return_file(): # https://blog.miguelgrinberg.com/post/handling-file-uploads-with-flask onnx_file = request.files['file'] - # print(onnx_file.filename) - # print(onnx_file.stream) - # onnx_file.save(onnx_file.filename) - + onnx_file_name = onnx_file.filename onnx_file_stream = onnx_file.stream - - # https://leimao.github.io/blog/ONNX-IO-Stream/ - onnx_file_stream.seek(0) - model_proto_from_stream = onnx.load_model(onnx_file_stream, onnx.ModelProto) - model_proto_bytes = onnx._serialize(model_proto_from_stream) - inference_session = rt.InferenceSession(model_proto_bytes) + + + # onnx_file_stream.seek(0) + # model_proto_from_stream = onnx.load_model(onnx_file_stream, onnx.ModelProto) + # print(model_proto_from_stream.graph) + # model_proto_bytes = onnx._serialize(model_proto_from_stream) + # inference_session = rt.InferenceSession(model_proto_bytes) # onnx.save_model(model_proto_from_binary_stream, onnx_file.filename) + return 'OK', 200 + + +@app.route('/download', methods=['POST']) +def modify_and_download_model(): + # modelNodeStates = request.get_json() + # print(modelNodeStates) + + + + + return 'OK', 200 if __name__ == '__main__': diff --git a/onnx_modifier.py b/onnx_modifier.py new file mode 100644 index 0000000..2df4e44 --- /dev/null +++ b/onnx_modifier.py @@ -0,0 +1,59 @@ +# https://leimao.github.io/blog/ONNX-Python-API/ +# https://github.com/saurabh-shandilya/onnx-utils + +import os +import onnx + +class onnxModifier: + def __init__(self, model_name, model_proto): + self.model_name = model_name + self.model_proto = model_proto + self.graph = self.model_proto.graph + + self.gen_node_name2module_map() + + def gen_node_name2module_map(self): + self.node_name2module = dict() + node_idx = 0 + for node in self.graph.node: + if node.name == '': + node.name = str(node.op_type) + str(node_idx) + node_idx += 1 + + self.node_name2module[node.name] = node + + # print(self.node_name2module.keys()) + + @classmethod + def from_model_path(cls, model_path): + model_name = os.path.basename(model_path) + model_proto = onnx.load(model_path) + return cls(model_name, model_proto) + + @classmethod + def from_name_stream(cls, name, stream): + # https://leimao.github.io/blog/ONNX-IO-Stream/ + stream.seek(0) + model_proto = onnx.load_model(stream, onnx.ModelProto) + return cls(name, model_proto) + + def remove_node_by_name(self, node_name): + self.graph.node.remove(self.node_name2module[node_name]) + + def check_and_save_model(self, save_dir='./'): + save_path = os.path.join(save_dir, 'modified_' + self.model_name) + onnx.checker.check_model(self.model_proto) + onnx.save(self.model_proto, save_path) + + +if __name__ == "__main__": + model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx" + onnx_modifer = onnxModifier.from_model_path(model_path) + onnx_modifer.remove_node_by_name('Softmax_nc_rename_64') + onnx_modifer.check_and_save_model() + + + + + + \ No newline at end of file