diff --git a/static/onnx.js b/static/onnx.js index 873eb2b..ee1a9d8 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -435,6 +435,9 @@ onnx.Graph = class { context = new onnx.GraphContext(context, graph.node); this._context = context; + + this._custom_add_node_io_idx = 0 + this._custom_added_node = [] // model parameter assignment here! // console.log(graph) @@ -505,19 +508,95 @@ onnx.Graph = class { } get nodes() { - return this._nodes; + // return this._nodes; + return this._nodes.concat(this._custom_added_node); + } + + reset_custom_added_node() { + this._custom_added_node = [] + this._custom_add_node_io_idx = 0 } toString() { return 'graph(' + this.name + ')'; } - make_empty_node (node_info) { + make_custom_add_node(node_info) { // type of node_info == LightNodeInfo const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain')); console.log(schema) + + // console.log(node_info.attributes) + // console.log(node_info.inputs) + // console.log(node_info.outputs) + // var max_input = schema.max_input + // var min_input = schema.max_input + var max_custom_add_input_num = Math.min(schema.max_input, 5) // set at most 5 custom_add inputs + var max_custom_add_output_num = Math.min(schema.max_output, 5) // set at most 5 custom_add inputs + + var inputs = [] + for (let i = 0; i < schema.inputs.length; ++i) { + const input = schema.inputs[i] + var arg_list = [] + if (input.list) { + for (let j = 0; j < max_custom_add_input_num; ++j) { + var arg_name = 'custom_input_' + (this.custom_add_node_io_idx++).toString() + arg_list.push(this._context.argument(arg_name)) + } + } + else { + var arg_name = 'custom_input_' + (this.custom_add_node_io_idx++).toString() + arg_list = [this._context.argument(arg_name)] + } + inputs.push(new onnx.Parameter(input.name, arg_list)); + } + var outputs = [] + for (let i = 0; i < schema.outputs.length; ++i) { + const output = schema.outputs[i] + var arg_list = [] + if (output.list) { + for (let j = 0; j < max_custom_add_output_num; ++j) { + var arg_name = 'custom_output_' + (this.custom_add_node_io_idx++).toString() + arg_list.push(this._context.argument(arg_name)) + } + } + else { + var arg_name = 'custom_output_' + (this.custom_add_node_io_idx++).toString() + arg_list = [this._context.argument(arg_name)] + } + outputs.push(new onnx.Parameter(output.name, arg_list)); + } - + // console.log(inputs) + // console.log(outputs) + + var attributes = [] + if (schema.attributes) { + for (const attr of schema.attributes) { + attributes.push( + new onnx.LightAttributeInfo( + attr.name, + attr.description, + attr.type + ) + ) + } + } + // console.log(attributes) + + var custom_add_node = new onnx.Node( + this._context, + node_info.properties.get('op_type'), + node_info.properties.get('domain'), + node_info.properties.get('name'), + schema.description, + attributes, + inputs, + outputs + ); + + // console.log(custom_add_node) + this._custom_added_node.push(custom_add_node) } }; @@ -655,9 +734,10 @@ onnx.Attribute = class { // `context` here is GraphContext constructor(context, op_type, domain, attribute) { this._name = attribute.name; - this._description = attribute.doc_string || ''; + this._description = attribute.doc_string || attribute.description || ''; this._type = null; this._value = null; + switch (attribute.type) { case onnx.AttributeType.FLOAT: this._value = attribute.f; @@ -723,7 +803,11 @@ onnx.Attribute = class { this._type = 'type[]'; break; default: - throw new onnx.Error("Unknown attribute type '" + attribute.type + "'."); + // console.log(attribute) + this._value = null; + this._type = attribute.type; + // TODO: I comment the Error message for the compatibility of onnx.Graph.make_custom_add_node. This is unsafe + // throw new onnx.Error("Unknown attribute type '" + attribute.type + "'."); } // console.log(attribute.type) // console.log(this._value) @@ -767,6 +851,15 @@ onnx.Attribute = class { } }; +onnx.LightAttributeInfo = class { + constructor(name, description, type, value) { + this.name = name; + this.description = description; + this.type = type; + this.value = value || null; + } +} + onnx.Group = class { constructor(name, groups) { @@ -1062,7 +1155,7 @@ onnx.Tensor = class { }; this._values = decode(this._values); if (!this._values) { - context.state = 'Tensor data is empty.'; + context.state = 'Tensor data is custom_add.'; return context; } this._indices = decode(this._indices); @@ -1515,6 +1608,10 @@ onnx.AttributeType = { TYPE_PROTOS: 14 }; +onnx.AttributeTypeFromSchema = { + +} + onnx.ModelContext = class { constructor(metadata, imageFormat) { @@ -1776,6 +1873,7 @@ onnx.GraphContext = class { outputs.push(new onnx.Parameter(output.name, list)); i += count; } + // console.log(schema) // console.log(node) node = new onnx.Node(this, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs); this._nodes.push(node); diff --git a/static/view.js b/static/view.js index 04b7980..3cbf873 100644 --- a/static/view.js +++ b/static/view.js @@ -457,23 +457,7 @@ view.View = class { } get activeGraph() { - // return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; - var active_graph = Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; - // console.log(this._addedNode) - // console.log(active_graph) - - if (this.lastViewGraph) { - // console.log(this.lastViewGraph._addedNode) - for (const node of this.lastViewGraph._addedNode) { - // console.log(node) - var empty_node = active_graph.make_empty_node(node) - - - - } - } - - return active_graph + return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null; } _updateGraph(model, graphs) { @@ -487,6 +471,7 @@ view.View = class { this.lastViewGraph = this._graph; const graph = this.activeGraph; + // console.log(graph.nodes) // console.log("_updateGraph is called"); return this._timeout(100).then(() => { @@ -1114,12 +1099,18 @@ view.Graph = class extends grapher.Graph { properties.set('domain', op_domain) properties.set('op_type', op_type) properties.set('name', modelNodeName) - // console.log(properties) - this._addedNode.push(new view.LightNodeInfo(properties)) - console.log(this._addedNode) - + // console.log(this._addedNode) + + // refresh + this.view._graphs[0].reset_custom_added_node() + for (const node_info of this._addedNode) { + // console.log(node) + this.view._graphs[0].make_custom_add_node(node_info) + } + // console.log(this.view._graphs[0].nodes) + this.view._updateGraph() }