support scrolling to the last page position when updating model

1123
ZhangGe6 3 years ago
parent e5d60110b1
commit 3dc195bf53

@ -1,5 +1,7 @@
# TODO
- [ ] **ensure the model is fully loaded before modify() is called.**
- otherwise `NameError: name 'onnx_modifier' is not defined` error will be invoked.
- [ ] support desktop application.
- [x] Windows
- [ ] Linux

@ -1,5 +1,10 @@
# onnx-modifier update log
## 20220921
- add argparse module for custom config
- fix model parsing issue when the model is load in a drag-and-drop way
- support editing batch size
## 20220813
- support adding model input/output node. [issue 7](https://github.com/ZhangGe6/onnx-modifier/issues/7), [issue 8](https://github.com/ZhangGe6/onnx-modifier/issues/8), [issue 13](https://github.com/ZhangGe6/onnx-modifier/issues/13)

@ -4,6 +4,7 @@
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
import os
import time
import copy
import struct
import numpy as np
@ -245,7 +246,8 @@ class onnxModifier:
print(out.shape)
if __name__ == "__main__":
model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx"
model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx"
# model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx"
onnx_modifier = onnxModifier.from_model_path(model_path)
def explore_basic():

@ -1095,6 +1095,7 @@ onnx.Tensor = class {
context.limit = 10000;
const value = this._decode(context, 0);
// console.log(value)
// console.log(onnx.Tensor._stringify(value, '', ' '))
return onnx.Tensor._stringify(value, '', ' ');
}
@ -1273,7 +1274,7 @@ onnx.Tensor = class {
}
static _stringify(value, indentation, indent) {
// console.log(value, Array.isArray(value)) // ..., false
// console.log(value, Array.isArray(value))
if (Array.isArray(value)) {
const result = [];
result.push(indentation + '[');

@ -888,7 +888,7 @@ sidebar.ArgumentView = class {
let name = this._argument.name || '';
this._hasId = name ? true : false;
this._hasKind = initializer && initializer.kind ? true : false;
// console.log(this._hasId, this._hasKind, type) // true true ...
// console.log(name, this._hasId, this._hasKind, type)
if (this._hasId || (!this._hasKind && !type)) {
this._hasId = true;
@ -964,6 +964,7 @@ sidebar.ArgumentView = class {
if (type && (this._hasId || this._hasKind)) {
const typeLine = this._host.document.createElement('div');
typeLine.className = 'sidebar-view-item-value-line-border';
// console.log(type, type.split('<').join('&lt;').split('>').join('&gt;'))
typeLine.innerHTML = 'type: <code><b>' + type.split('<').join('&lt;').split('>').join('&gt;') + '</b></code>';
this._element.appendChild(typeLine);
}

@ -25,6 +25,9 @@ view.View = class {
direction: 'vertical',
mousewheel: 'scroll'
};
this.lastScrollLeft = 0;
this.lastScrollTop = 0;
this._zoom = 1;
this._host.initialize(this).then(() => {
this._model = null;
this._graphs = [];
@ -481,6 +484,7 @@ view.View = class {
}
this.lastViewGraph = this._graph;
const graph = this.activeGraph;
// console.log(graph.nodes)
return this._timeout(100).then(() => {
@ -555,7 +559,7 @@ view.View = class {
return Promise.resolve();
}
else {
this._zoom = 1;
// this._zoom = 1;
const groups = graph.groups;
const nodes = graph.nodes;
@ -583,6 +587,10 @@ view.View = class {
viewGraph._addedOutputs = this.lastViewGraph._addedOutputs;
// console.log(viewGraph._renameMap);
// console.log(viewGraph._modelNodeName2State)
const container = this._getElementById('graph');
this.lastScrollLeft = container.scrollLeft;
this.lastScrollTop = container.scrollTop;
}
viewGraph.add(graph);
@ -601,13 +609,12 @@ view.View = class {
viewGraph.build(this._host.document, origin);
this._zoom = 1;
// this._zoom = 1;
return this._timeout(20).then(() => {
viewGraph.update();
// 让画面可拖动/缩放 =======>
const elements = Array.from(canvas.getElementsByClassName('graph-input') || []);
if (elements.length === 0) {
const nodeElements = Array.from(canvas.getElementsByClassName('graph-node') || []);
@ -631,40 +638,46 @@ view.View = class {
canvas.setAttribute('width', width);
canvas.setAttribute('height', height);
this._zoom = 1;
this._updateZoom(this._zoom);
const container = this._getElementById('graph');
if (elements && elements.length > 0) {
// Center view based on input elements
const xs = [];
const ys = [];
for (let i = 0; i < elements.length; i++) {
const element = elements[i];
const rect = element.getBoundingClientRect();
xs.push(rect.left + (rect.width / 2));
ys.push(rect.top + (rect.height / 2));
}
let x = xs[0];
const y = ys[0];
if (ys.every(y => y === ys[0])) {
x = xs.reduce((a, b) => a + b, 0) / xs.length;
}
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
// console.log(this.lastScrollLeft, this.lastScrollTop, this._zoom)
if (this.lastScrollLeft != 0 || this.lastScrollTop != 0 || this._zoom != 1) {
// console.log("scrolling")
this._updateZoom(this._zoom);
container.scrollTo({ left: this.lastScrollLeft, top: this.lastScrollTop, behavior: 'auto' });
}
else {
const canvasRect = canvas.getBoundingClientRect();
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
this._zoom = 1;
this._updateZoom(this._zoom);
if (elements && elements.length > 0) {
// Center view based on input elements
const xs = [];
const ys = [];
for (let i = 0; i < elements.length; i++) {
const element = elements[i];
const rect = element.getBoundingClientRect();
xs.push(rect.left + (rect.width / 2));
ys.push(rect.top + (rect.height / 2));
}
let x = xs[0];
const y = ys[0];
if (ys.every(y => y === ys[0])) {
x = xs.reduce((a, b) => a + b, 0) / xs.length;
}
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
}
else {
const canvasRect = canvas.getBoundingClientRect();
const graphRect = container.getBoundingClientRect();
const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2);
const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2);
container.scrollTo({ left: left, top: top, behavior: 'auto' });
}
}
// <======= 让画面可拖动/缩放
this._graph = viewGraph;
return;
});
@ -1021,6 +1034,11 @@ view.View = class {
}
}
reloadLastLocation() {
const container = this._getElementById('graph');
}
};
view.Graph = class extends grapher.Graph {
@ -1308,10 +1326,16 @@ view.Graph = class extends grapher.Graph {
this._renameMap = new Map();
// clear custom added nodes
this._addedNode = new Map()
this.view._graphs[0].reset_custom_added_node()
this._addedOutputs = []
this.view._graphs[0].reset_custom_added_outputs()
this._addedNode = new Map();
this.view._graphs[0].reset_custom_added_node();
this._addedOutputs = [];
this.view._graphs[0].reset_custom_added_outputs();
// reset load location
var container = this.view._getElementById('graph');
container.scrollLeft = 0;
container.scrollTop = 0;
this.view._zoom = 1;
}
recordRenameInfo(modelNodeName, src_name, dst_name) {

Loading…
Cancel
Save