From 228fc77d6d43aa180fda6d9460953890a49aacfd Mon Sep 17 00:00:00 2001 From: ZhangGe6 Date: Thu, 28 Apr 2022 23:22:22 +0800 Subject: [PATCH] good job. All encountered problems have been solved --- app.py | 1 + onnx_modifier.py | 35 +++++++++++++++++++---------------- static/index.js | 4 ++-- templates/index.html | 6 +++--- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/app.py b/app.py index df532c5..a6c82ce 100644 --- a/app.py +++ b/app.py @@ -23,6 +23,7 @@ def modify_and_download_model(): node_states = json.loads(request.get_json()) # print(node_states) + onnx_modifier.reload() # allow for downloading for multiple times onnx_modifier.remove_node_by_node_states(node_states) onnx_modifier.check_and_save_model() diff --git a/onnx_modifier.py b/onnx_modifier.py index 56417a7..7eee1ce 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -2,14 +2,30 @@ # https://github.com/saurabh-shandilya/onnx-utils # https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model -import io import os +import copy import onnx class onnxModifier: def __init__(self, model_name, model_proto): self.model_name = model_name - self.model_proto = model_proto + self.model_proto_backup = model_proto + + @classmethod + def from_model_path(cls, model_path): + model_name = os.path.basename(model_path) + model_proto = onnx.load(model_path) + return cls(model_name, model_proto) + + @classmethod + def from_name_stream(cls, name, stream): + # https://leimao.github.io/blog/ONNX-IO-Stream/ + stream.seek(0) + model_proto = onnx.load_model(stream, onnx.ModelProto) + return cls(name, model_proto) + + def reload(self): + self.model_proto = copy.deepcopy(self.model_proto_backup) self.graph = self.model_proto.graph self.initializer = self.model_proto.graph.initializer @@ -34,20 +50,7 @@ class onnxModifier: self.initilizer_name2module = dict() for initializer in self.initializer: self.initilizer_name2module[initializer.name] = initializer - - @classmethod - def from_model_path(cls, model_path): - model_name = os.path.basename(model_path) - model_proto = onnx.load(model_path) - return cls(model_name, model_proto) - - @classmethod - def from_name_stream(cls, name, stream): - # https://leimao.github.io/blog/ONNX-IO-Stream/ - stream.seek(0) - model_proto = onnx.load_model(stream, onnx.ModelProto) - return cls(name, model_proto) - + def remove_node_by_name(self, node_name): # remove node in graph self.graph.node.remove(self.node_name2module[node_name]) diff --git a/static/index.js b/static/index.js index 96551f7..b50f1ef 100644 --- a/static/index.js +++ b/static/index.js @@ -201,8 +201,8 @@ host.BrowserHost = class { click: () => this._about() }); - const refreshButton = this.document.getElementById('refresh-graph'); - refreshButton.addEventListener('click', () => { + const previewButton = this.document.getElementById('preview-graph'); + previewButton.addEventListener('click', () => { this._view._updateGraph(); }) diff --git a/templates/index.html b/templates/index.html index cafdb9c..2c6def3 100644 --- a/templates/index.html +++ b/templates/index.html @@ -46,7 +46,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI" .toolbar-back-button:hover { background: #000000; border-color: #000000; } .toolbar-name-button { float: left; background: rgba(255, 255, 255, 0.95); border-top-right-radius: 6px; border-bottom-right-radius: 6px; border: 1px solid #777; color: #777; border-left: 1px; border-left-color: #ffffff; margin: 2px 0 2px 0; padding: 0 12px 0 6px; cursor: pointer; width: auto; height: 20px; font-size: 11px; line-height: 0; transition: 0.1s; } .toolbar-name-button:hover { color: #000000; } -.graph-op-button-refresh { +.graph-op-button-preview { cursor: pointer; background-color: white; border: 1px solid grey; @@ -58,7 +58,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI" left: 2px; top: 30px; } -.graph-op-button-refresh:active { background: #e7e7e7; } +.graph-op-button-preview:active { background: #e7e7e7; } .graph-op-button-reset { cursor: pointer; background-color: white; @@ -251,7 +251,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI" - +