diff --git a/app.py b/app.py
index df532c5..a6c82ce 100644
--- a/app.py
+++ b/app.py
@@ -23,6 +23,7 @@ def modify_and_download_model():
node_states = json.loads(request.get_json())
# print(node_states)
+ onnx_modifier.reload() # allow for downloading for multiple times
onnx_modifier.remove_node_by_node_states(node_states)
onnx_modifier.check_and_save_model()
diff --git a/onnx_modifier.py b/onnx_modifier.py
index 56417a7..7eee1ce 100644
--- a/onnx_modifier.py
+++ b/onnx_modifier.py
@@ -2,14 +2,30 @@
# https://github.com/saurabh-shandilya/onnx-utils
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
-import io
import os
+import copy
import onnx
class onnxModifier:
def __init__(self, model_name, model_proto):
self.model_name = model_name
- self.model_proto = model_proto
+ self.model_proto_backup = model_proto
+
+ @classmethod
+ def from_model_path(cls, model_path):
+ model_name = os.path.basename(model_path)
+ model_proto = onnx.load(model_path)
+ return cls(model_name, model_proto)
+
+ @classmethod
+ def from_name_stream(cls, name, stream):
+ # https://leimao.github.io/blog/ONNX-IO-Stream/
+ stream.seek(0)
+ model_proto = onnx.load_model(stream, onnx.ModelProto)
+ return cls(name, model_proto)
+
+ def reload(self):
+ self.model_proto = copy.deepcopy(self.model_proto_backup)
self.graph = self.model_proto.graph
self.initializer = self.model_proto.graph.initializer
@@ -34,20 +50,7 @@ class onnxModifier:
self.initilizer_name2module = dict()
for initializer in self.initializer:
self.initilizer_name2module[initializer.name] = initializer
-
- @classmethod
- def from_model_path(cls, model_path):
- model_name = os.path.basename(model_path)
- model_proto = onnx.load(model_path)
- return cls(model_name, model_proto)
-
- @classmethod
- def from_name_stream(cls, name, stream):
- # https://leimao.github.io/blog/ONNX-IO-Stream/
- stream.seek(0)
- model_proto = onnx.load_model(stream, onnx.ModelProto)
- return cls(name, model_proto)
-
+
def remove_node_by_name(self, node_name):
# remove node in graph
self.graph.node.remove(self.node_name2module[node_name])
diff --git a/static/index.js b/static/index.js
index 96551f7..b50f1ef 100644
--- a/static/index.js
+++ b/static/index.js
@@ -201,8 +201,8 @@ host.BrowserHost = class {
click: () => this._about()
});
- const refreshButton = this.document.getElementById('refresh-graph');
- refreshButton.addEventListener('click', () => {
+ const previewButton = this.document.getElementById('preview-graph');
+ previewButton.addEventListener('click', () => {
this._view._updateGraph();
})
diff --git a/templates/index.html b/templates/index.html
index cafdb9c..2c6def3 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -46,7 +46,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
.toolbar-back-button:hover { background: #000000; border-color: #000000; }
.toolbar-name-button { float: left; background: rgba(255, 255, 255, 0.95); border-top-right-radius: 6px; border-bottom-right-radius: 6px; border: 1px solid #777; color: #777; border-left: 1px; border-left-color: #ffffff; margin: 2px 0 2px 0; padding: 0 12px 0 6px; cursor: pointer; width: auto; height: 20px; font-size: 11px; line-height: 0; transition: 0.1s; }
.toolbar-name-button:hover { color: #000000; }
-.graph-op-button-refresh {
+.graph-op-button-preview {
cursor: pointer;
background-color: white;
border: 1px solid grey;
@@ -58,7 +58,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
left: 2px;
top: 30px;
}
-.graph-op-button-refresh:active { background: #e7e7e7; }
+.graph-op-button-preview:active { background: #e7e7e7; }
.graph-op-button-reset {
cursor: pointer;
background-color: white;
@@ -251,7 +251,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
-
+