local onnxModifier done

1123
ZhangGe6 4 years ago
parent 890e54ae20
commit ebc280dc92

@ -1,4 +1,3 @@
import json
from flask import Flask, render_template, request
import onnx
import onnxruntime as rt
@ -8,29 +7,33 @@ app = Flask(__name__)
def index():
return render_template('index.html')
@app.route('/download', methods=['POST'])
def downloadModel():
modelNodeStates = request.get_json()
print(modelNodeStates)
return 'OK', 200
@app.route('/return_file', methods=['POST'])
def return_file():
# https://blog.miguelgrinberg.com/post/handling-file-uploads-with-flask
onnx_file = request.files['file']
# print(onnx_file.filename)
# print(onnx_file.stream)
# onnx_file.save(onnx_file.filename)
onnx_file_name = onnx_file.filename
onnx_file_stream = onnx_file.stream
# https://leimao.github.io/blog/ONNX-IO-Stream/
onnx_file_stream.seek(0)
model_proto_from_stream = onnx.load_model(onnx_file_stream, onnx.ModelProto)
model_proto_bytes = onnx._serialize(model_proto_from_stream)
inference_session = rt.InferenceSession(model_proto_bytes)
# onnx_file_stream.seek(0)
# model_proto_from_stream = onnx.load_model(onnx_file_stream, onnx.ModelProto)
# print(model_proto_from_stream.graph)
# model_proto_bytes = onnx._serialize(model_proto_from_stream)
# inference_session = rt.InferenceSession(model_proto_bytes)
# onnx.save_model(model_proto_from_binary_stream, onnx_file.filename)
return 'OK', 200
@app.route('/download', methods=['POST'])
def modify_and_download_model():
# modelNodeStates = request.get_json()
# print(modelNodeStates)
return 'OK', 200
if __name__ == '__main__':

@ -0,0 +1,59 @@
# https://leimao.github.io/blog/ONNX-Python-API/
# https://github.com/saurabh-shandilya/onnx-utils
import os
import onnx
class onnxModifier:
def __init__(self, model_name, model_proto):
self.model_name = model_name
self.model_proto = model_proto
self.graph = self.model_proto.graph
self.gen_node_name2module_map()
def gen_node_name2module_map(self):
self.node_name2module = dict()
node_idx = 0
for node in self.graph.node:
if node.name == '':
node.name = str(node.op_type) + str(node_idx)
node_idx += 1
self.node_name2module[node.name] = node
# print(self.node_name2module.keys())
@classmethod
def from_model_path(cls, model_path):
model_name = os.path.basename(model_path)
model_proto = onnx.load(model_path)
return cls(model_name, model_proto)
@classmethod
def from_name_stream(cls, name, stream):
# https://leimao.github.io/blog/ONNX-IO-Stream/
stream.seek(0)
model_proto = onnx.load_model(stream, onnx.ModelProto)
return cls(name, model_proto)
def remove_node_by_name(self, node_name):
self.graph.node.remove(self.node_name2module[node_name])
def check_and_save_model(self, save_dir='./'):
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
onnx_modifer = onnxModifier.from_model_path(model_path)
onnx_modifer.remove_node_by_name('Softmax_nc_rename_64')
onnx_modifer.check_and_save_model()
Loading…
Cancel
Save