diff --git a/docs/todo_list.md b/docs/todo_list.md index dc198fb..2ff0948 100644 --- a/docs/todo_list.md +++ b/docs/todo_list.md @@ -1,5 +1,7 @@ # TODO +- [ ] **ensure the model is fully loaded before modify() is called.** + - otherwise `NameError: name 'onnx_modifier' is not defined` error will be invoked. - [ ] support desktop application. - [x] Windows - [ ] Linux diff --git a/docs/update_log.md b/docs/update_log.md index 5d9b360..b6815a4 100644 --- a/docs/update_log.md +++ b/docs/update_log.md @@ -1,5 +1,10 @@ # onnx-modifier update log +## 20220921 +- add argparse module for custom config +- fix model parsing issue when the model is load in a drag-and-drop way +- support editing batch size + ## 20220813 - support adding model input/output node. [issue 7](https://github.com/ZhangGe6/onnx-modifier/issues/7), [issue 8](https://github.com/ZhangGe6/onnx-modifier/issues/8), [issue 13](https://github.com/ZhangGe6/onnx-modifier/issues/13) diff --git a/onnx_modifier.py b/onnx_modifier.py index 872eca6..fefb67f 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -4,6 +4,7 @@ # https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model import os +import time import copy import struct import numpy as np @@ -203,7 +204,7 @@ class onnxModifier: # print(modify_info['node_renamed_io']) # print(modify_info['node_changed_attr']) # print(modify_info['added_node_info']) - # print(modify_info['added_outputs']) + # print(modify_info['added_outputs']) self.change_batch_size(modify_info['rebatch_info']) self.remove_node_by_node_states(modify_info['node_states']) self.modify_node_io_name(modify_info['node_renamed_io']) @@ -245,7 +246,8 @@ class onnxModifier: print(out.shape) if __name__ == "__main__": - model_path = "C:\\Users\\ZhangGe\\Desktop\\best.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): diff --git a/static/onnx.js b/static/onnx.js index 2e786d7..a3e0c60 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -1095,6 +1095,7 @@ onnx.Tensor = class { context.limit = 10000; const value = this._decode(context, 0); // console.log(value) + // console.log(onnx.Tensor._stringify(value, '', ' ')) return onnx.Tensor._stringify(value, '', ' '); } @@ -1273,7 +1274,7 @@ onnx.Tensor = class { } static _stringify(value, indentation, indent) { - // console.log(value, Array.isArray(value)) // ..., false + // console.log(value, Array.isArray(value)) if (Array.isArray(value)) { const result = []; result.push(indentation + '['); diff --git a/static/view-sidebar.js b/static/view-sidebar.js index b2d6f89..85ce65e 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -888,7 +888,7 @@ sidebar.ArgumentView = class { let name = this._argument.name || ''; this._hasId = name ? true : false; this._hasKind = initializer && initializer.kind ? true : false; - // console.log(this._hasId, this._hasKind, type) // true true ... + // console.log(name, this._hasId, this._hasKind, type) if (this._hasId || (!this._hasKind && !type)) { this._hasId = true; @@ -964,6 +964,7 @@ sidebar.ArgumentView = class { if (type && (this._hasId || this._hasKind)) { const typeLine = this._host.document.createElement('div'); typeLine.className = 'sidebar-view-item-value-line-border'; + // console.log(type, type.split('<').join('<').split('>').join('>')) typeLine.innerHTML = 'type: ' + type.split('<').join('<').split('>').join('>') + ''; this._element.appendChild(typeLine); } diff --git a/static/view.js b/static/view.js index ef1e5aa..6551369 100644 --- a/static/view.js +++ b/static/view.js @@ -25,6 +25,9 @@ view.View = class { direction: 'vertical', mousewheel: 'scroll' }; + this.lastScrollLeft = 0; + this.lastScrollTop = 0; + this._zoom = 1; this._host.initialize(this).then(() => { this._model = null; this._graphs = []; @@ -481,6 +484,7 @@ view.View = class { } this.lastViewGraph = this._graph; const graph = this.activeGraph; + // console.log(graph.nodes) return this._timeout(100).then(() => { @@ -555,7 +559,7 @@ view.View = class { return Promise.resolve(); } else { - this._zoom = 1; + // this._zoom = 1; const groups = graph.groups; const nodes = graph.nodes; @@ -583,6 +587,10 @@ view.View = class { viewGraph._addedOutputs = this.lastViewGraph._addedOutputs; // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) + + const container = this._getElementById('graph'); + this.lastScrollLeft = container.scrollLeft; + this.lastScrollTop = container.scrollTop; } viewGraph.add(graph); @@ -601,13 +609,12 @@ view.View = class { viewGraph.build(this._host.document, origin); - this._zoom = 1; + // this._zoom = 1; return this._timeout(20).then(() => { viewGraph.update(); - // 让画面可拖动/缩放 =======> const elements = Array.from(canvas.getElementsByClassName('graph-input') || []); if (elements.length === 0) { const nodeElements = Array.from(canvas.getElementsByClassName('graph-node') || []); @@ -631,40 +638,46 @@ view.View = class { canvas.setAttribute('width', width); canvas.setAttribute('height', height); - this._zoom = 1; - this._updateZoom(this._zoom); - const container = this._getElementById('graph'); - if (elements && elements.length > 0) { - // Center view based on input elements - const xs = []; - const ys = []; - for (let i = 0; i < elements.length; i++) { - const element = elements[i]; - const rect = element.getBoundingClientRect(); - xs.push(rect.left + (rect.width / 2)); - ys.push(rect.top + (rect.height / 2)); + // console.log(this.lastScrollLeft, this.lastScrollTop, this._zoom) + if (this.lastScrollLeft != 0 || this.lastScrollTop != 0 || this._zoom != 1) { + // console.log("scrolling") + this._updateZoom(this._zoom); + container.scrollTo({ left: this.lastScrollLeft, top: this.lastScrollTop, behavior: 'auto' }); + } + else { + this._zoom = 1; + this._updateZoom(this._zoom); + + if (elements && elements.length > 0) { + // Center view based on input elements + const xs = []; + const ys = []; + for (let i = 0; i < elements.length; i++) { + const element = elements[i]; + const rect = element.getBoundingClientRect(); + xs.push(rect.left + (rect.width / 2)); + ys.push(rect.top + (rect.height / 2)); + } + let x = xs[0]; + const y = ys[0]; + if (ys.every(y => y === ys[0])) { + x = xs.reduce((a, b) => a + b, 0) / xs.length; + } + const graphRect = container.getBoundingClientRect(); + const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2); + const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2); + container.scrollTo({ left: left, top: top, behavior: 'auto' }); } - let x = xs[0]; - const y = ys[0]; - if (ys.every(y => y === ys[0])) { - x = xs.reduce((a, b) => a + b, 0) / xs.length; + else { + const canvasRect = canvas.getBoundingClientRect(); + const graphRect = container.getBoundingClientRect(); + const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2); + const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2); + container.scrollTo({ left: left, top: top, behavior: 'auto' }); } - const graphRect = container.getBoundingClientRect(); - const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2); - const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2); - container.scrollTo({ left: left, top: top, behavior: 'auto' }); - } - else { - const canvasRect = canvas.getBoundingClientRect(); - const graphRect = container.getBoundingClientRect(); - const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2); - const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2); - container.scrollTo({ left: left, top: top, behavior: 'auto' }); } - // <======= 让画面可拖动/缩放 - this._graph = viewGraph; return; }); @@ -1021,6 +1034,11 @@ view.View = class { } } + reloadLastLocation() { + const container = this._getElementById('graph'); + + + } }; view.Graph = class extends grapher.Graph { @@ -1308,10 +1326,16 @@ view.Graph = class extends grapher.Graph { this._renameMap = new Map(); // clear custom added nodes - this._addedNode = new Map() - this.view._graphs[0].reset_custom_added_node() - this._addedOutputs = [] - this.view._graphs[0].reset_custom_added_outputs() + this._addedNode = new Map(); + this.view._graphs[0].reset_custom_added_node(); + this._addedOutputs = []; + this.view._graphs[0].reset_custom_added_outputs(); + + // reset load location + var container = this.view._getElementById('graph'); + container.scrollLeft = 0; + container.scrollTop = 0; + this.view._zoom = 1; } recordRenameInfo(modelNodeName, src_name, dst_name) {