good job. All encountered problems have been solved

1123
ZhangGe6 4 years ago
parent 4ff46e6acf
commit 228fc77d6d

@ -23,6 +23,7 @@ def modify_and_download_model():
node_states = json.loads(request.get_json())
# print(node_states)
onnx_modifier.reload() # allow for downloading for multiple times
onnx_modifier.remove_node_by_node_states(node_states)
onnx_modifier.check_and_save_model()

@ -2,14 +2,30 @@
# https://github.com/saurabh-shandilya/onnx-utils
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
import io
import os
import copy
import onnx
class onnxModifier:
def __init__(self, model_name, model_proto):
self.model_name = model_name
self.model_proto = model_proto
self.model_proto_backup = model_proto
@classmethod
def from_model_path(cls, model_path):
model_name = os.path.basename(model_path)
model_proto = onnx.load(model_path)
return cls(model_name, model_proto)
@classmethod
def from_name_stream(cls, name, stream):
# https://leimao.github.io/blog/ONNX-IO-Stream/
stream.seek(0)
model_proto = onnx.load_model(stream, onnx.ModelProto)
return cls(name, model_proto)
def reload(self):
self.model_proto = copy.deepcopy(self.model_proto_backup)
self.graph = self.model_proto.graph
self.initializer = self.model_proto.graph.initializer
@ -34,20 +50,7 @@ class onnxModifier:
self.initilizer_name2module = dict()
for initializer in self.initializer:
self.initilizer_name2module[initializer.name] = initializer
@classmethod
def from_model_path(cls, model_path):
model_name = os.path.basename(model_path)
model_proto = onnx.load(model_path)
return cls(model_name, model_proto)
@classmethod
def from_name_stream(cls, name, stream):
# https://leimao.github.io/blog/ONNX-IO-Stream/
stream.seek(0)
model_proto = onnx.load_model(stream, onnx.ModelProto)
return cls(name, model_proto)
def remove_node_by_name(self, node_name):
# remove node in graph
self.graph.node.remove(self.node_name2module[node_name])

@ -201,8 +201,8 @@ host.BrowserHost = class {
click: () => this._about()
});
const refreshButton = this.document.getElementById('refresh-graph');
refreshButton.addEventListener('click', () => {
const previewButton = this.document.getElementById('preview-graph');
previewButton.addEventListener('click', () => {
this._view._updateGraph();
})

@ -46,7 +46,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
.toolbar-back-button:hover { background: #000000; border-color: #000000; }
.toolbar-name-button { float: left; background: rgba(255, 255, 255, 0.95); border-top-right-radius: 6px; border-bottom-right-radius: 6px; border: 1px solid #777; color: #777; border-left: 1px; border-left-color: #ffffff; margin: 2px 0 2px 0; padding: 0 12px 0 6px; cursor: pointer; width: auto; height: 20px; font-size: 11px; line-height: 0; transition: 0.1s; }
.toolbar-name-button:hover { color: #000000; }
.graph-op-button-refresh {
.graph-op-button-preview {
cursor: pointer;
background-color: white;
border: 1px solid grey;
@ -58,7 +58,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
left: 2px;
top: 30px;
}
.graph-op-button-refresh:active { background: #e7e7e7; }
.graph-op-button-preview:active { background: #e7e7e7; }
.graph-op-button-reset {
cursor: pointer;
background-color: white;
@ -251,7 +251,7 @@ button { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI"
</svg>
</button>
<button id="refresh-graph" class="graph-op-button-refresh">Refresh</button>
<button id="preview-graph" class="graph-op-button-preview">Preview</button>
<button id="reset-graph" class="graph-op-button-reset">Reset</button>
<button id="download-graph" class="graph-op-button-download">Download</button>

Loading…
Cancel
Save