support scrolling to the last page position when updating model

1123
ZhangGe6 3 years ago
parent e5d60110b1
commit 3dc195bf53

@ -1,5 +1,7 @@
# TODO # TODO
- [ ] **ensure the model is fully loaded before modify() is called.**
- otherwise `NameError: name 'onnx_modifier' is not defined` error will be invoked.
- [ ] support desktop application. - [ ] support desktop application.
- [x] Windows - [x] Windows
- [ ] Linux - [ ] Linux

@ -1,5 +1,10 @@
# onnx-modifier update log # onnx-modifier update log
## 20220921
- add argparse module for custom config
- fix model parsing issue when the model is load in a drag-and-drop way
- support editing batch size
## 20220813 ## 20220813
- support adding model input/output node. [issue 7](https://github.com/ZhangGe6/onnx-modifier/issues/7), [issue 8](https://github.com/ZhangGe6/onnx-modifier/issues/8), [issue 13](https://github.com/ZhangGe6/onnx-modifier/issues/13) - support adding model input/output node. [issue 7](https://github.com/ZhangGe6/onnx-modifier/issues/7), [issue 8](https://github.com/ZhangGe6/onnx-modifier/issues/8), [issue 13](https://github.com/ZhangGe6/onnx-modifier/issues/13)

@ -4,6 +4,7 @@
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model # https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
import os import os
import time
import copy import copy
import struct import struct
import numpy as np import numpy as np
@ -203,7 +204,7 @@ class onnxModifier:
# print(modify_info['node_renamed_io']) # print(modify_info['node_renamed_io'])
# print(modify_info['node_changed_attr']) # print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info']) # print(modify_info['added_node_info'])
# print(modify_info['added_outputs']) # print(modify_info['added_outputs'])
self.change_batch_size(modify_info['rebatch_info']) self.change_batch_size(modify_info['rebatch_info'])
self.remove_node_by_node_states(modify_info['node_states']) self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io']) self.modify_node_io_name(modify_info['node_renamed_io'])
@ -245,7 +246,8 @@ class onnxModifier:
print(out.shape) print(out.shape)
if __name__ == "__main__": if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx" model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path) onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic(): def explore_basic():

@ -1095,6 +1095,7 @@ onnx.Tensor = class {
context.limit = 10000; context.limit = 10000;
const value = this._decode(context, 0); const value = this._decode(context, 0);
// console.log(value) // console.log(value)
// console.log(onnx.Tensor._stringify(value, '', ' '))
return onnx.Tensor._stringify(value, '', ' '); return onnx.Tensor._stringify(value, '', ' ');
} }
@ -1273,7 +1274,7 @@ onnx.Tensor = class {
} }
static _stringify(value, indentation, indent) { static _stringify(value, indentation, indent) {
// console.log(value, Array.isArray(value)) // ..., false // console.log(value, Array.isArray(value))
if (Array.isArray(value)) { if (Array.isArray(value)) {
const result = []; const result = [];
result.push(indentation + '['); result.push(indentation + '[');

@ -888,7 +888,7 @@ sidebar.ArgumentView = class {
let name = this._argument.name || ''; let name = this._argument.name || '';
this._hasId = name ? true : false; this._hasId = name ? true : false;
this._hasKind = initializer && initializer.kind ? true : false; this._hasKind = initializer && initializer.kind ? true : false;
// console.log(this._hasId, this._hasKind, type) // true true ... // console.log(name, this._hasId, this._hasKind, type)
if (this._hasId || (!this._hasKind && !type)) { if (this._hasId || (!this._hasKind && !type)) {
this._hasId = true; this._hasId = true;
@ -964,6 +964,7 @@ sidebar.ArgumentView = class {
if (type && (this._hasId || this._hasKind)) { if (type && (this._hasId || this._hasKind)) {
const typeLine = this._host.document.createElement('div'); const typeLine = this._host.document.createElement('div');
typeLine.className = 'sidebar-view-item-value-line-border'; typeLine.className = 'sidebar-view-item-value-line-border';
// console.log(type, type.split('<').join('&lt;').split('>').join('&gt;'))
typeLine.innerHTML = 'type: <code><b>' + type.split('<').join('&lt;').split('>').join('&gt;') + '</b></code>'; typeLine.innerHTML = 'type: <code><b>' + type.split('<').join('&lt;').split('>').join('&gt;') + '</b></code>';
this._element.appendChild(typeLine); this._element.appendChild(typeLine);
} }

@ -25,6 +25,9 @@ view.View = class {
direction: 'vertical', direction: 'vertical',
mousewheel: 'scroll' mousewheel: 'scroll'
}; };
this.lastScrollLeft = 0;
this.lastScrollTop = 0;
this._zoom = 1;
this._host.initialize(this).then(() => { this._host.initialize(this).then(() => {
this._model = null; this._model = null;
this._graphs = []; this._graphs = [];
@ -481,6 +484,7 @@ view.View = class {
} }
this.lastViewGraph = this._graph; this.lastViewGraph = this._graph;
const graph = this.activeGraph; const graph = this.activeGraph;
// console.log(graph.nodes) // console.log(graph.nodes)
return this._timeout(100).then(() => { return this._timeout(100).then(() => {
@ -555,7 +559,7 @@ view.View = class {
return Promise.resolve(); return Promise.resolve();
} }
else { else {
this._zoom = 1; // this._zoom = 1;
const groups = graph.groups; const groups = graph.groups;
const nodes = graph.nodes; const nodes = graph.nodes;
@ -583,6 +587,10 @@ view.View = class {
viewGraph._addedOutputs = this.lastViewGraph._addedOutputs; viewGraph._addedOutputs = this.lastViewGraph._addedOutputs;
// console.log(viewGraph._renameMap); // console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State) // console.log(viewGraph._modelNodeName2State)
const container = this._getElementById('graph');
this.lastScrollLeft = container.scrollLeft;
this.lastScrollTop = container.scrollTop;
} }
viewGraph.add(graph); viewGraph.add(graph);
@ -601,13 +609,12 @@ view.View = class {
viewGraph.build(this._host.document, origin); viewGraph.build(this._host.document, origin);
this._zoom = 1; // this._zoom = 1;
return this._timeout(20).then(() => { return this._timeout(20).then(() => {
viewGraph.update(); viewGraph.update();
// 让画面可拖动/缩放 =======>
const elements = Array.from(canvas.getElementsByClassName('graph-input') || []); const elements = Array.from(canvas.getElementsByClassName('graph-input') || []);
if (elements.length === 0) { if (elements.length === 0) {
const nodeElements = Array.from(canvas.getElementsByClassName('graph-node') || []); const nodeElements = Array.from(canvas.getElementsByClassName('graph-node') || []);
@ -631,40 +638,46 @@ view.View = class {
canvas.setAttribute('width', width); canvas.setAttribute('width', width);
canvas.setAttribute('height', height); canvas.setAttribute('height', height);
this._zoom = 1;
this._updateZoom(this._zoom);
const container = this._getElementById('graph'); const container = this._getElementById('graph');
if (elements && elements.length > 0) { // console.log(this.lastScrollLeft, this.lastScrollTop, this._zoom)
// Center view based on input elements if (this.lastScrollLeft != 0 || this.lastScrollTop != 0 || this._zoom != 1) {
const xs = []; // console.log("scrolling")
const ys = []; this._updateZoom(this._zoom);
for (let i = 0; i < elements.length; i++) { container.scrollTo({ left: this.lastScrollLeft, top: this.lastScrollTop, behavior: 'auto' });
const element = elements[i]; }
const rect = element.getBoundingClientRect(); else {
xs.push(rect.left + (rect.width / 2)); this._zoom = 1;
ys.push(rect.top + (rect.height / 2)); this._updateZoom(this._zoom);
if (elements && elements.length > 0) {
// Center view based on input elements
const xs = [];
const ys = [];
for (let i = 0; i < elements.length; i++) {
const element = elements[i];
const rect = element.getBoundingClientRect();
xs.push(rect.left + (rect.width / 2));
ys.push(rect.top + (rect.height / 2));
}
let x = xs[0];
const y = ys[0];
if (ys.every(y => y === ys[0])) {
x = xs.reduce((a, b) => a + b, 0) / xs.length;
}
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
} }
let x = xs[0]; else {
const y = ys[0]; const canvasRect = canvas.getBoundingClientRect();
if (ys.every(y => y === ys[0])) { const graphRect = container.getBoundingClientRect();
x = xs.reduce((a, b) => a + b, 0) / xs.length; const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
} }
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
}
else {
const canvasRect = canvas.getBoundingClientRect();
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
} }
// <======= 让画面可拖动/缩放
this._graph = viewGraph; this._graph = viewGraph;
return; return;
}); });
@ -1021,6 +1034,11 @@ view.View = class {
} }
} }
reloadLastLocation() {
const container = this._getElementById('graph');
}
}; };
view.Graph = class extends grapher.Graph { view.Graph = class extends grapher.Graph {
@ -1308,10 +1326,16 @@ view.Graph = class extends grapher.Graph {
this._renameMap = new Map(); this._renameMap = new Map();
// clear custom added nodes // clear custom added nodes
this._addedNode = new Map() this._addedNode = new Map();
this.view._graphs[0].reset_custom_added_node() this.view._graphs[0].reset_custom_added_node();
this._addedOutputs = [] this._addedOutputs = [];
this.view._graphs[0].reset_custom_added_outputs() this.view._graphs[0].reset_custom_added_outputs();
// reset load location
var container = this.view._getElementById('graph');
container.scrollLeft = 0;
container.scrollTop = 0;
this.view._zoom = 1;
} }
recordRenameInfo(modelNodeName, src_name, dst_name) { recordRenameInfo(modelNodeName, src_name, dst_name) {

Loading…
Cancel
Save