diff --git a/docs/onnx_modifier_logo.png b/docs/onnx_modifier_logo.png deleted file mode 100644 index 73e8c90..0000000 Binary files a/docs/onnx_modifier_logo.png and /dev/null differ diff --git a/static/onnx.js b/static/onnx.js index 59ebd9d..873eb2b 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -334,7 +334,7 @@ onnx.Model = class { for (const func of model.functions || []) { context.metadata.add(new onnx.Function(context, func)); } - // var tmp = this.supported_ops + // var tmp = this.supported_nodes const graphs = [ model.graph ]; while (graphs.length > 0) { const graph = graphs.shift(); @@ -407,17 +407,17 @@ onnx.Model = class { return this._graphs; } - get supported_ops() { - console.log(this.graphMetadata); - var ops = [] + get supported_nodes() { + // console.log(this.graphMetadata); + var nodes = [] for (const domain of this.graphMetadata._metadata._map.keys()) { // console.log(domain) for (const op of this.graphMetadata._metadata._map.get(domain).keys()) { // console.log(op) - ops.push(op) + nodes.push([domain, op]) } } - return ops + return nodes } }; @@ -434,6 +434,7 @@ onnx.Graph = class { this._description = graph.doc_string || ''; context = new onnx.GraphContext(context, graph.node); + this._context = context; // model parameter assignment here! // console.log(graph) @@ -510,6 +511,15 @@ onnx.Graph = class { toString() { return 'graph(' + this.name + ')'; } + + make_empty_node (node_info) { + // type of node_info == LightNodeInfo + const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain')); + console.log(schema) + + + } + }; onnx.Parameter = class { diff --git a/static/view-grapher.js b/static/view-grapher.js index 722be6e..b2f5271 100644 --- a/static/view-grapher.js +++ b/static/view-grapher.js @@ -20,6 +20,8 @@ grapher.Graph = class { this._pathArgumentNames = new Set(); // the name of arguments which occurs in both sides of an edge this._renameMap = new Map(); + + this._addedNode = []; } get options() { diff --git a/static/view-sidebar.js b/static/view-sidebar.js index 7ef7f66..2548e1b 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -221,6 +221,8 @@ sidebar.NodeSidebar = class { this.add_separator(this._elements, 'sidebar-view-separator'); this._addHeader('Add children node'); this._addDropdownSelector('AddChildrenNode'); + this.add_span() + this._addButton('Add Node'); } @@ -341,17 +343,44 @@ sidebar.NodeSidebar = class { this._host._view._graph.reset_node(this._modelNodeName) }); } - } - _addDropdownSelector() { - const selectorElement = this._host.document.createElement('SELECT'); - this._elements.push(selectorElement); - // console.log(this._host._view._model.supported_ops) - for (const op of this._host._view._model.supported_ops) { - var option = new Option(op, op); - selectorElement.appendChild(option); + if (title === 'Add Node') { + buttonElement.addEventListener('click', () => { + // console.log(this.add_op_type) + this._host._view._graph.add_node(this.add_op_domain, this.add_op_type) + }); } + } + + _addDropdownSelector(title) { + if (title === 'AddChildrenNode') { + const selectorElement = this._host.document.createElement('SELECT'); + selectorElement.setAttribute('id', 'sidebar-AddChildrenNode'); + this._elements.push(selectorElement); + // console.log(this._host._view._model.supported_nodes) + for (const node of this._host._view._model.supported_nodes) { + // node: [domain, op] + // console.log(node) + // console.log(node[0]) + // console.log(node[1]) + + var option = new Option(node[1], node[0] + ':' + node[1]); + // console.log(option) + selectorElement.appendChild(option); + } + + var selected_val = selectorElement.options[selectorElement.selectedIndex].value + this.add_op_domain = selected_val.split(':')[0] + this.add_op_type = selected_val.split(':')[1] + // console.log(selectorElement.options[selectorElement.selectedIndex].text) + // console.log(selectorElement.options[selectorElement.selectedIndex].value) + selectorElement.addEventListener('change', () => { + var selected_val = selectorElement.options[selectorElement.selectedIndex].value + this.add_op_domain = selected_val.split(':')[0] + this.add_op_type = selected_val.split(':')[1] + }); + } } toggleInput(name) { diff --git a/static/view.js b/static/view.js index 72258fe..04b7980 100644 --- a/static/view.js +++ b/static/view.js @@ -457,7 +457,23 @@ view.View = class { } get activeGraph() { - return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; + // return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; + var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; + // console.log(this._addedNode) + // console.log(active_graph) + + if (this.lastViewGraph) { + // console.log(this.lastViewGraph._addedNode) + for (const node of this.lastViewGraph._addedNode) { + // console.log(node) + var empty_node = active_graph.make_empty_node(node) + + + + } + } + + return active_graph } _updateGraph(model, graphs) { @@ -468,9 +484,10 @@ view.View = class { this._model = model; this._graphs = graphs; } + this.lastViewGraph = this._graph; + const graph = this.activeGraph; - - this.lastViewGraph = this._graph; + // console.log("_updateGraph is called"); return this._timeout(100).then(() => { if (graph && graph != lastGraphs[0]) { @@ -568,6 +585,7 @@ view.View = class { // console.log('node state of lastViewGraph is loaded') viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State; viewGraph._renameMap = this.lastViewGraph._renameMap; + viewGraph._addedNode = this.lastViewGraph._addedNode; // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) } @@ -850,6 +868,9 @@ view.Graph = class extends grapher.Graph { this.model = model; this._arguments = new Map(); this._nodeKey = 0; + + // the node key of custom added node + this._add_nodeKey = 0; } createNode(node) { @@ -1010,6 +1031,12 @@ view.Graph = class extends grapher.Graph { } } + + // custom added node + for (const node of this._addedNode) { + + } + for (const output of graph.outputs) { const viewOutput = this.createOutput(output); for (const argument of output.arguments) { @@ -1072,6 +1099,29 @@ view.Graph = class extends grapher.Graph { this._renameMap.get(modelNodeName).set(src_name, dst_name); } + add_node(op_domain, op_type) { + // node_name: the added op name + // parent_node_name: parent modelNodeName + // console.log(node_name) + + var node_id = (this._add_nodeKey++).toString(); // in case input (onnx) node has no name + var modelNodeName = 'custom_added_' + op_type + node_id + + // console.log(op_type) + // console.log(modelNodeName) + + var properties = new Map() + properties.set('domain', op_domain) + properties.set('op_type', op_type) + properties.set('name', modelNodeName) + + // console.log(properties) + + this._addedNode.push(new view.LightNodeInfo(properties)) + console.log(this._addedNode) + + } + build(document, origin) { for (const argument of this._arguments.values()) { @@ -1233,6 +1283,7 @@ view.Node = class extends grapher.Node { } }; + view.Input = class extends grapher.Node { constructor(context, value) { @@ -1291,6 +1342,15 @@ view.Output = class extends grapher.Node { } }; +view.LightNodeInfo = class { + constructor(properties, attributes, inputs, outputs) { + this.properties = properties + this.attributes = attributes || [] + this.inputs = inputs || [] + this.outputs = outputs || [] + } +} + view.Argument = class { constructor(context, argument) {