remove output node locally

1123
ZhangGe6 4 years ago
parent ebc280dc92
commit 02837fd309

4
.gitignore vendored

@ -0,0 +1,4 @@
__pycache__/
gym/
*.onnx

@ -1,6 +1,6 @@
# Flask entry point for the ONNX-modifier web UI.
# NOTE(review): reconstructed from a side-by-side diff render; this is the
# post-commit (new-side) text of the module prologue.
from flask import Flask, render_template, request
import json

from onnx_modifier import onnxModifier

app = Flask(__name__)
@app.route('/') @app.route('/')
@ -11,27 +11,20 @@ def index():
def return_file():
    """Receive an uploaded .onnx file and build the global onnxModifier.

    File-upload handling follows
    https://blog.miguelgrinberg.com/post/handling-file-uploads-with-flask

    Returns:
        ('OK', 200) once the model has been parsed from the upload stream.
    """
    onnx_file = request.files['file']
    # Keep the modifier in a module-level global so the later /download
    # request can operate on the same loaded model.
    global onnx_modifier
    onnx_modifier = onnxModifier.from_name_stream(onnx_file.filename, onnx_file.stream)
    return 'OK', 200
@app.route('/download', methods=['POST'])
def modify_and_download_model():
    """Apply the client's node edit states and save the modified model.

    The client POSTs a JSON-encoded string mapping node name -> state
    (e.g. 'Deleted'); request.get_json() yields that string here, so it
    is decoded once more with json.loads.

    Returns:
        ('OK', 200) after the model is checked and written to disk.
    """
    node_states = json.loads(request.get_json())
    onnx_modifier.remove_node_by_node_states(node_states)
    onnx_modifier.check_and_save_model()
    return 'OK', 200

@ -1,7 +1,8 @@
# https://leimao.github.io/blog/ONNX-Python-API/ # https://leimao.github.io/blog/ONNX-Python-API/
# https://github.com/saurabh-shandilya/onnx-utils # https://github.com/saurabh-shandilya/onnx-utils
import io
import os import os
from platform import node
import onnx import onnx
class onnxModifier: class onnxModifier:
@ -15,13 +16,20 @@ class onnxModifier:
def gen_node_name2module_map(self):
    """Build self.node_name2module, mapping names to graph entries.

    Registers graph inputs, computation nodes, and graph outputs in a
    single dict so deletion requests can be resolved by name.  Nodes with
    an empty name get a synthetic one (op_type + running index) so every
    entry has a usable key.  Also records self.graph_output_names so
    callers can tell outputs apart from ordinary nodes.
    """
    self.node_name2module = dict()
    node_idx = 0
    for node in self.graph.input:
        node_idx += 1
        self.node_name2module[node.name] = node
    for node in self.graph.node:
        if node.name == '':
            # Synthesize a unique name; the index keeps duplicates apart.
            node.name = str(node.op_type) + str(node_idx)
        node_idx += 1
        self.node_name2module[node.name] = node
    for node in self.graph.output:
        node_idx += 1
        self.node_name2module[node.name] = node
    # Output names are tracked separately: outputs live in graph.output,
    # not graph.node, and need a different removal path.
    self.graph_output_names = [node.name for node in self.graph.output]
@classmethod @classmethod
@ -33,6 +41,7 @@ class onnxModifier:
@classmethod
def from_name_stream(cls, name, stream):
    """Alternate constructor: build an onnxModifier from an upload stream.

    Stream-based ONNX I/O per https://leimao.github.io/blog/ONNX-IO-Stream/

    Args:
        name: original filename of the uploaded model.
        stream: binary file-like object holding the serialized model.
    """
    # The stream may already have been read by the web framework; rewind
    # before parsing so load_model sees the whole protobuf.
    stream.seek(0)
    model_proto = onnx.load_model(stream, onnx.ModelProto)
    return cls(name, model_proto)
@ -40,18 +49,45 @@ class onnxModifier:
def remove_node_by_name(self, node_name):
    """Delete one computation node from the graph by its mapped name."""
    self.graph.node.remove(self.node_name2module[node_name])
def remove_output_by_name(self, node_name):
    """Delete one graph output by name.

    Outputs are ValueInfo entries stored in graph.output rather than
    graph.node, so they need their own removal path.
    """
    self.graph.output.remove(self.node_name2module[node_name])
def remove_node_by_node_states(self, node_states):
    """Apply client edit states: remove every entry marked 'Deleted'.

    Args:
        node_states: mapping of node name -> state string from the UI.

    Dispatches on membership in self.graph_output_names because graph
    outputs and ordinary nodes live in different repeated fields.
    """
    for node_name, node_state in node_states.items():
        if node_state == 'Deleted':
            if node_name in self.graph_output_names:
                self.remove_output_by_name(node_name)
            else:
                self.remove_node_by_name(node_name)
def check_and_save_model(self, save_dir='./res_onnx'):
    """Validate the modified model and write it as 'modified_<name>'.

    Args:
        save_dir: target directory for the saved model.

    NOTE(review): assumes save_dir already exists — os.makedirs is not
    called, so onnx.save fails if './res_onnx' is missing; confirm the
    directory is created elsewhere.
    """
    save_path = os.path.join(save_dir, 'modified_' + self.model_name)
    onnx.checker.check_model(self.model_proto)
    onnx.save(self.model_proto, save_path)
def inference(self):
    """Placeholder for running the modified model with onnxruntime.

    Intended sketch (not yet implemented):
    serialize self.model_proto to bytes and feed them to an
    onnxruntime InferenceSession.
    """
    pass
if __name__ == "__main__":
    # Ad-hoc smoke test: load a local model, delete its Softmax node and
    # the graph output that node fed, then validate and save the result.
    # NOTE(review): hard-coded developer-machine path; only usable locally.
    model_path = "C:\\Users\\ZhangGe\\Desktop\\squeezenet1.0-12-int8.onnx"
    onnx_modifier = onnxModifier.from_model_path(model_path)
    onnx_modifier.remove_node_by_name('Softmax_nc_rename_64')
    onnx_modifier.remove_output_by_name('softmaxout_1')
    onnx_modifier.check_and_save_model()

@ -322,10 +322,9 @@ sidebar.NodeSidebar = class {
// body: this._mapToJson( // body: this._mapToJson(
// this._host._view._graph._modelNodeName2State // this._host._view._graph._modelNodeName2State
// ) // )
body: JSON.stringify({ body: JSON.stringify(
"onnx_file_path" : null, this._mapToJson(this._host._view._graph._modelNodeName2State)
"node_states": this._mapToJson(this._host._view._graph._modelNodeName2State), )
})
}).then(function (response) { }).then(function (response) {
return response.text(); return response.text();
}).then(function (text) { }).then(function (text) {

Loading…
Cancel
Save