diff --git a/docs/edit_initializer.gif b/docs/edit_initializer.gif new file mode 100644 index 0000000..7a749dc Binary files /dev/null and b/docs/edit_initializer.gif differ diff --git a/docs/update_log.md b/docs/update_log.md index b6815a4..f9740c4 100644 --- a/docs/update_log.md +++ b/docs/update_log.md @@ -1,5 +1,9 @@ # onnx-modifier update log +## 20221026 +- support scrolling to the last page position when updating model +- support editing initializer feature + ## 20220921 - add argparse module for custom config - fix model parsing issue when the model is load in a drag-and-drop way @@ -18,7 +22,7 @@ ## 20220621 - add Dockerfile - - thanks to [fengwang](https://github.com/fengwang) + - thanks to [fengwang](https://github.com/fengwang) and [this PR](https://github.com/ZhangGe6/onnx-modifier/pulls?q=is%3Apr+is%3Aclosed) ## 20220620 diff --git a/onnx_modifier.py b/onnx_modifier.py index e4bf5f3..321aafc 100644 --- a/onnx_modifier.py +++ b/onnx_modifier.py @@ -179,6 +179,7 @@ class onnxModifier: for node_info in nodes_info.values(): if node_states[node_info['properties']['name']] == "Deleted": continue + print(node_info) node = make_new_node(node_info) # print(node) @@ -261,6 +262,7 @@ class onnxModifier: inference_session = rt.InferenceSession(model_proto_bytes) if not x: + np.random.seed(0) x = np.random.randn(*input_shape).astype(np.float32) if not output_names: output_name = self.graph.node[-1].output[0] @@ -272,11 +274,14 @@ class onnxModifier: input_name = inference_session.get_inputs()[0].name out = inference_session.run(output_names, {input_name: x})[0] print(out.shape) + # print(out[0][0][0][0]) if __name__ == "__main__": # model_path = "C:\\Users\\ZhangGe\\Desktop\\resnet18-v2-7.onnx" # model_path = "C:\\Users\\ZhangGe\\Desktop\\movenet_lightning.onnx" - model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx" + model_path = "C:\\Users\\ZhangGe\\Desktop\\modified_test_edit_init.onnx" + # model_path = "C:\\Users\\ZhangGe\\Desktop\\test_edit_init.onnx" onnx_modifier = onnxModifier.from_model_path(model_path) def explore_basic(): @@ -346,7 +351,7 @@ if __name__ == "__main__": onnx_modifier.add_nodes(node_info) - onnx_modifier.inference() + onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=["onnx::Transpose_368"]) onnx_modifier.check_and_save_model() # test_add_node() @@ -360,8 +365,8 @@ if __name__ == "__main__": # test_change_node_attr() def test_inference(): - onnx_modifier.inference() - # test_inference() + onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=["onnx::Transpose_368"]) + test_inference() def test_add_output(): # print(onnx_modifier.graph.output) @@ -394,14 +399,11 @@ if __name__ == "__main__": # test_modify_primary_initializer() def test_modify_new_initializer(): - modify_info = {'node_states': {'input': 'Exist', 'Conv_0': 'Exist', 'LeakyRelu_1': 'Exist', 'Conv_2': 'Exist', 'LeakyRelu_3': 'Exist', 'Conv_4': 'Exist', 'LeakyRelu_5': 'Exist', 'Conv_6': 'Exist', 'LeakyRelu_7': 'Exist', 'Conv_8': 'Exist', 'LeakyRelu_9': 'Exist', 'Conv_10': 'Exist', 'Conv_11': 'Exist', 'LeakyRelu_12': 'Exist', 'Conv_13': 'Exist', 'Conv_14': 'Exist', 'LeakyRelu_15': 'Exist', 'Conv_16': 'Exist', 'Concat_17': 'Exist', 'LeakyRelu_18': 'Exist', 'Conv_19': 'Exist', 'Sigmoid_20': 'Exist', 'Mul_22': 'Exist', 'Conv_23': 'Exist', 'LeakyRelu_24': 'Exist', 'Conv_25': 'Exist', 'Conv_26': 'Exist', 'LeakyRelu_27': 'Exist', 'Conv_28': 'Exist', 'Add_29': 'Exist', 'Conv_30': 'Exist', 'Conv_31': 'Exist', 'LeakyRelu_32': 'Exist', 'Conv_33': 'Exist', 'Conv_34': 'Exist', 'LeakyRelu_35': 'Exist', 'Conv_36': 'Exist', 'Concat_37': 'Exist', 'LeakyRelu_38': 'Exist', 'Conv_39': 'Exist', - 'Conv_40': 'Exist', 'LeakyRelu_41': 'Exist', 'Conv_42': 'Exist', 'LeakyRelu_43': 'Exist', 'Conv_44': 'Exist', 'Conv_45': 'Exist', 'LeakyRelu_46': 'Exist', 'Concat_47': 'Exist', 'Reshape_49': 'Exist', 'out_onnx::Transpose_368': 'Exist', 'custom_added_Reshape0': 'Exist', 'out_custom_output_2': 'Exist'}, 'node_renamed_io': {}, 'node_changed_attr': {}, 'added_node_info': {'custom_added_Reshape0': {'properties': {'domain': 'ai.onnx', 'op_type': 'Reshape', 'name': 'custom_added_Reshape0'}, 'attributes': {}, 'inputs': {'data': ['onnx::Transpose_368'], 'shape': ['custom_input_1']}, 'outputs': {'reshaped': ['custom_output_2']}}}, 'added_outputs': {'0': 'custom_output_2'}, 'rebatch_info': {}, 'changed_initializer': {'custom_input_1': ['int64', '[1, 2, 32, 24, 6]']}} + modify_info = {'node_states': {'input': 'Exist', 'Conv_0': 'Exist', 'LeakyRelu_1': 'Exist', 'Conv_2': 'Exist', 'LeakyRelu_3': 'Exist', 'Conv_4': 'Exist', 'LeakyRelu_5': 'Exist', 'Conv_6': 'Exist', 'LeakyRelu_7': 'Exist', 'Conv_8': 'Exist', 'LeakyRelu_9': 'Exist', 'Conv_10': 'Exist', 'Conv_11': 'Exist', 'LeakyRelu_12': 'Exist', 'Conv_13': 'Exist', 'Conv_14': 'Exist', 'LeakyRelu_15': 'Exist', 'Conv_16': 'Exist', 'Concat_17': 'Exist', 'LeakyRelu_18': 'Exist', 'Conv_19': 'Exist', 'Sigmoid_20': 'Exist', 'Mul_22': 'Exist', 'Conv_23': 'Exist', 'LeakyRelu_24': 'Exist', 'Conv_25': 'Exist', 'Conv_26': 'Exist', 'LeakyRelu_27': 'Exist', 'Conv_28': 'Exist', 'Add_29': 'Exist', 'Conv_30': 'Exist', 'Conv_31': 'Exist', 'LeakyRelu_32': 'Exist', 'Conv_33': 'Exist', 'Conv_34': 'Exist', 'LeakyRelu_35': 'Exist', 'Conv_36': 'Exist', 'Concat_37': 'Exist', 'LeakyRelu_38': 'Exist', 'Conv_39': 'Exist', 'Conv_40': 'Exist', 'LeakyRelu_41': 'Exist', 'Conv_42': 'Exist', 'LeakyRelu_43': 'Exist', 'Conv_44': 'Exist', 'Conv_45': 'Exist', 'LeakyRelu_46': 'Exist', 'Concat_47': 'Exist', 'Reshape_49': 'Exist', 'out_onnx::Transpose_368': 'Exist', 'custom_added_Slice0': 'Exist'}, 'node_renamed_io': {}, 'node_changed_attr': {}, 'added_node_info': {'custom_added_Slice0': {'properties': {'domain': 'ai.onnx', 'op_type': 'Slice', 'name': 'custom_added_Slice0'}, 'attributes': {}, 'inputs': {'data': ['custom_input_0'], 'starts': ['custom_input_1'], 'ends': ['custom_input_2'], 'output': ['custom_output_5']}, 'outputs': {}}}, 'added_outputs': {}, 'rebatch_info': {}, 'changed_initializer': {'custom_input_1': ['int64', '[0,0,0,0]'], 'custom_input_2': ['int64', '[1,1,192,192]']}} onnx_modifier.modify(modify_info) onnx_modifier.check_and_save_model() onnx_modifier.inference(input_shape=[1, 1, 192, 192], output_names=['custom_output_2']) - print(onnx_modifier.initializer_name2module.keys()) - for initializer in onnx_modifier.initializer: - print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}") - - test_modify_new_initializer() - \ No newline at end of file + # print(onnx_modifier.initializer_name2module.keys()) + # for initializer in onnx_modifier.initializer: + # print(f"Tensor Name: {initializer.name}, Data Type: {initializer.data_type}, Shape: {initializer.dims}") + # test_modify_new_initializer() \ No newline at end of file diff --git a/readme.md b/readme.md index 8e9ecca..8304090 100644 --- a/readme.md +++ b/readme.md @@ -6,7 +6,7 @@ English | [简体中文](readme_zh-CN.md) To edit an ONNX model, One common way is to visualize the model graph, and edit it using ONNX Python API. This works fine. However, we have to code to edit, then visualize to check. The two processes may iterate for many times, which is time-consuming. 👋 -What if we have a tool, which allow us to **edit and preview the editing effect in a totally visualization fashion**? +What if we have a tool, which allows us to **edit and preview the editing effect in a totally visualization fashion**? Then `onnx-modifier` comes. With it, we can focus on editing the model graph in the visualization pannel. All the editing information will be summarized and processed by Python ONNX API automatically at last. Then our time can be saved! 🚀 @@ -24,6 +24,7 @@ Currently, the following editing operations are supported: - [x] Edit the attribute of nodes - [x] Add new nodes (experimental) - [x] Change batch size +- [x] Edit model initializers Here is the [update log](./docs/update_log.md) and [TODO list](./docs/todo_list.md). @@ -176,19 +177,13 @@ Note there is an `Add node` button, following with a selector elements on the to The following are some notes for this feature: -1. :warning: Currently, adding nodes with initializer (such as weight parameters) is not supported (such as `Conv`, `BatchNormalization`). Adding nodes without initializer are tested and work as expected in my tested case (such as `Flatten`, `ArgMax`, `Concat`). +1. By clicking the `?` in the `NODE PROPERTIES -> type` element, or the `+` in each `Attribute` element, we can get some reference to help us fill the node information. -2. Click the selector and type the first letter for the new node type (`f` for `Flatten` node for example), we can be quickly navigated to the node type. +2. It is suggested to fill all of the `Attribute`, without leaving them as `undefined`. The default value may not be supported well in the current version. -3. By clicking the `?` in the `NODE PROPERTIES -> type` element, or the `+` in each `Attribute` element, we can get some reference to help us fill the node information. +3. For the `Attribute` with type `list`, items are split with '`,`' (comma). Note that `[]` is not needed. -4. It is suggested to fill all of the `Attribute`, without leaving them as `undefined`. The default value may not be supported well in the current version. - -5. For the `Attribute` with type `list`, items are split with '`,`' (comma) - -6. For the `Inputs/Outputs` with type `list`, it is forced to be at most 8 elements in the current version. If the actual inputs/outputs number is less than 8, we can leave the unused items with the name starting with `list_custom`, and they will be automatically omitted. - -7. This feature is experimentally supported now and may be not very robust. So any issues are warmly welcomed if some unexpected results are encountered. +4. For the `Inputs/Outputs` with type `list`, it is forced to be at most 8 elements in the current version. If the actual inputs/outputs number is less than 8, we can leave the unused items with the name starting with `list_custom`, and they will be automatically omitted. ## Change batch size `onnx-modifier` supports editing batch size now. Both `Dynamic batch size` and `Fixed batch size` modes are supported. @@ -201,6 +196,13 @@ Note the differences between `fixed batch size inference` and `dynamic batch siz > - When running a model with only fixed dimensions, the ONNX Runtime will prepare and optimize the graph for execution when constructing the Inference Session. > - when the model has dynamic dimensions like batch size, the ONNX Runtime may instead cache optimized graphs for specific batch sizes when inputs are first encountered for that batch size. +## Edit model initializers +Sometimes we want to edit the values which are stored in model initializers, such as the weight/bias of a convolution layer and the shape parameter of a `Reshape` node. `onnx-modifier` supports this feature now! Input a new value for the initializer in the invoked sidebar and click Download, then we are done. + + + +> Note: For the newly added node, we should also input the datatype of the initializer. (If we are not sure what the datatype is, click `NODE PROPERTIES->type->?`, we may get some clues.) + # Sample models For quick testing, some typical sample models are provided as following. Most of them are from [onnx model zoo](https://github.com/onnx/models) diff --git a/readme_zh-CN.md b/readme_zh-CN.md index be18f97..23e8b66 100644 --- a/readme_zh-CN.md +++ b/readme_zh-CN.md @@ -23,8 +23,9 @@ - [x] 修改模型输入输出名 - [x] 增加模型输出节点 - [x] 编辑节点属性值 -- [x] 增加新节点(beta) +- [x] 增加新节点 - [x] 修改模型batch size +- [x] 修改模型initializers `onnx-modifier`基于流行的模型可视化工具 [Netron](https://github.com/lutzroeder/netron) 和轻量级Web应用框架 [flask](https://github.com/pallets/flask) 开发。希望它能给社区带来一些贡献~ @@ -129,7 +130,7 @@ -## 增加新节点(beta) +## 增加新节点 有时候我们希望向模型中添加新节点。`onnx-modifier`已开始支持该功能。 @@ -148,13 +149,10 @@ 以下是该功能的一些提醒和小tip: -1. 在当前版本中,是不支持添加含有参数的节点的(比如`Conv`, `BatchNormalization`)。其他大多数节点,在我的测试中,可正确添加(比如`Flatten`, `ArgMax`, `Concat`)。 -2. 点击selector,输入要添加的节点的首字母(比如`Flatten`的`f`),可帮我们定位到以该字母开头的节点列表区域,加快检索速度。 -3. 点击节点侧边栏的`NODE PROPERTIES`的`type`框右侧的`?`,和节点属性框右侧的`+`,可以显示关于当前节点类型/属性值的参考信息。 -4. 为确保正确性,节点的各属性值建议全部填写(而不是留着`undefined`)。默认值在当前版本可能支持得还不够好。 -5. 如果一个属性值是列表类型,则各元素之间使用‘`,`’分隔。 -6. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`),限制最多显示8个。如果一个节点实际输入/输出小于8个,则填写对应数目的输入输出即可,多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。 -7. 这个功能还处在开发中,可能会不够鲁棒。所以如果大家在实际使用时碰到问题,非常欢迎提issue! +1. 点击节点侧边栏的`NODE PROPERTIES`的`type`框右侧的`?`,和节点属性框右侧的`+`,可以显示关于当前节点类型/属性值的参考信息。 +2. 为确保正确性,节点的各属性值建议全部填写(而不是留着`undefined`)。默认值在当前版本可能支持得还不够好。 +3. 如果一个属性值是列表类型,则各元素之间使用‘`,`’分隔,无需'[]'。 +4. 在当前版本中,如果一个节点的输入/输出是一个列表类型(如`Concat`),限制最多显示8个。如果一个节点实际输入/输出小于8个,则填写对应数目的输入输出即可,多出来的应以`list_custom`开头,它们会在后续处理中自动被忽略。 ## 修改模型batch size 动态batch size和固定batch size均已支持。 @@ -163,6 +161,13 @@ +## 修改模型initializers +有时候我们要修改一些保存在模型initializer中的数值,比如卷积层的权重/偏置参数,`Reshape`节点的`shape`参数等。使用`onnx-modifier`,这一操作将非常简单:在对应节点侧边栏的initializer中键入新的数值,点击`Download`即可。 + + + +> 如果要修改我们**新增加的**节点的initializer,除了键入其数值之外,还要键入其数据类型。(如果我们不确定数据类型,可以点击`NODE PROPERTIES->type->?`,在弹出的节点的详细介绍界面中,可能会找到线索。) + `onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~ # 示例模型文件 diff --git a/static/index.js b/static/index.js index ad48f67..d2f49db 100644 --- a/static/index.js +++ b/static/index.js @@ -229,7 +229,8 @@ host.BrowserHost = class { 'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State), 'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap), 'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes), - 'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode)), + 'added_node_info' : this.mapToObjectRec(this.parseAddedLightNodeInfo2Map(this._view._graph._addedNode, + this._view._graph._initializerEditInfo)), 'added_outputs' : this.arrayToObject(this.process_added_outputs(this._view._graph._addedOutputs, this._view._graph._renameMap, this._view._graph._modelNodeName2State)), 'rebatch_info' : this.mapToObjectRec(this._view._graph._reBatchInfo), @@ -720,20 +721,74 @@ host.BrowserHost = class { } // convert view.LightNodeInfo to Map object for easier transmission to Python backend - parseLightNodeInfo2Map(nodes_info) { + parseAddedLightNodeInfo2Map(nodes_info, initializer_info) { + console.log(nodes_info) + console.log(initializer_info) var res_map = new Map() for (const [modelNodeName, node_info] of nodes_info) { var node_info_map = new Map() node_info_map.set('properties', node_info.properties) node_info_map.set('attributes', node_info.attributes) - node_info_map.set('inputs', node_info.inputs) - node_info_map.set('outputs', node_info.outputs) + // skip the input and output which is optional and has no initializer value + var inputs = new Map() + // console.log(node_info) + // console.log(node_info.inputs) + for (var [input_name, arg_list] of node_info.inputs) { + var filtered_arg_list = [] + for (var arg of arg_list) { + var arg_name = arg[0], arg_optional = arg[1]; + if (arg_optional) { + if (!initializer_info.get(arg_name) || initializer_info.get(arg_name) == "") { + continue; + } + } + filtered_arg_list.push(arg_name); + } + if (filtered_arg_list.length > 0) { + inputs.set(input_name, filtered_arg_list) + } + } + // console.log(inputs) + node_info_map.set('inputs', inputs) + + var outputs = new Map() + for (var [output_name, arg_list] of node_info.outputs) { + var filtered_arg_list = [] + for (var arg of arg_list) { + var arg_name = arg[0], arg_optional = arg[1]; + if (arg_optional) { + if (!initializer_info.get(arg_name) || initializer_info.get(arg_name) == "") { + continue; + } + } + filtered_arg_list.push(arg_name); + } + if (filtered_arg_list.length > 0) { + outputs.set(output_name, filtered_arg_list) + } + } + node_info_map.set('outputs', outputs) + res_map.set(modelNodeName, node_info_map) } + // console.log(res_map) return res_map } + + // rename the initializer if its corresponding argument name is changed (for primitive nodes) + process_initializer(initializer_info, rename_map) { + for (const [node_name, rename_pair] of rename_map) { + for (const [arg_orig_name, arg_new_name] of rename_pair) { + if (initializer_info.has(arg_orig_name)) { + var init_val = initializer_info.get(arg_orig_name) + initializer_info.set(arg_new_name, init_val) + initializer_info.delete(arg_orig_name) + } + } + } + } }; host.Dropdown = class { diff --git a/static/onnx.js b/static/onnx.js index 5e6f23c..1ab8904 100644 --- a/static/onnx.js +++ b/static/onnx.js @@ -549,7 +549,7 @@ onnx.Graph = class { if (input.list) { for (let j = 0; j < max_custom_add_input_num; ++j) { if (node_info_input && node_info_input[j]) { - var arg_name = node_info_input[j] + var arg_name = node_info_input[j][0] // [arg.name, arg.is_optional] } else { var arg_name = 'list_custom_input_' + (this._custom_add_node_io_idx++).toString() @@ -559,7 +559,7 @@ onnx.Graph = class { } else { if (node_info_input && node_info_input[0]) { - var arg_name = node_info_input[0] + var arg_name = node_info_input[0][0] // [arg.name, arg.is_optional] } else { var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() @@ -569,6 +569,9 @@ onnx.Graph = class { for (var arg of arg_list) { arg.is_custom_added = true; + if (input.option && input.option == 'optional') { + arg.is_optional = true; + } } inputs.push(new onnx.Parameter(input.name, arg_list)); } @@ -582,7 +585,7 @@ onnx.Graph = class { if (output.list) { for (let j = 0; j < max_custom_add_output_num; ++j) { if (node_info_output && node_info_output[j]) { - var arg_name = node_info_output[j] + var arg_name = node_info_output[j][0] } else { var arg_name = 'list_custom_output_' + (this._custom_add_node_io_idx++).toString() @@ -592,7 +595,7 @@ onnx.Graph = class { } else { if (node_info_output && node_info_output[0]) { - var arg_name = node_info_output[0] + var arg_name = node_info_output[0][0] } else { var arg_name = 'custom_output_' + (this._custom_add_node_io_idx++).toString() @@ -603,6 +606,9 @@ onnx.Graph = class { for (var arg of arg_list) { arg.is_custom_added = true; + if (output.option && output.option == 'optional') { + arg.is_optional = true; + } } outputs.push(new onnx.Parameter(output.name, arg_list)); } @@ -687,6 +693,7 @@ onnx.Argument = class { this.original_name = original_name || name; this.is_custom_added = false; + this.is_optional = false; } diff --git a/static/view-sidebar.js b/static/view-sidebar.js index edc77a8..9148bc5 100644 --- a/static/view-sidebar.js +++ b/static/view-sidebar.js @@ -265,7 +265,7 @@ sidebar.NodeSidebar = class { newNameElement.setAttribute('value', this._host._view._graph._renameMap.get(this._modelNodeName).get(argument.name)); } newNameElement.addEventListener('input', (e) => { - console.log(e.target.value); + // console.log(e.target.value); this._host._view._graph.recordRenameInfo(this._modelNodeName, argument.name, e.target.value); // console.log(this._host._view._graph._renameMap); }); @@ -1009,14 +1009,17 @@ sidebar.ArgumentView = class { editInitializerVal.innerHTML = 'This is an initializer, you can input a new value for it here:'; this._element.appendChild(editInitializerVal); - var inputInitializerVal = document.createElement("INPUT"); + var inputInitializerVal = document.createElement("textarea"); inputInitializerVal.setAttribute("type", "text"); - inputInitializerVal.setAttribute("size", "42"); + inputInitializerVal.rows = 1; + inputInitializerVal.cols = 44; + // reload the last value var orig_arg_name = this._host._view._graph.getOriginalName(this._param_type, this._modelNodeName, this._param_index, this._arg_index) if (this._host._view._graph._initializerEditInfo.get(orig_arg_name)) { // [type, value] - inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(orig_arg_name)[1]); + // inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(orig_arg_name)[1]); + inputInitializerVal.innerHTML = this._host._view._graph._initializerEditInfo.get(orig_arg_name)[1]; } inputInitializerVal.addEventListener('input', (e) => { @@ -1027,6 +1030,12 @@ sidebar.ArgumentView = class { } if (this._argument.is_custom_added) { + if (this._argument.is_optional) { + const isOptionalLine = this._host.document.createElement('div'); + isOptionalLine.className = 'sidebar-view-item-value-line-border'; + isOptionalLine.innerHTML = 'optional: true'; + this._element.appendChild(isOptionalLine); + } var new_init_val = "", new_init_type = ""; // ====== input value ======> const editInitializerVal = this._host.document.createElement('div'); @@ -1034,10 +1043,10 @@ sidebar.ArgumentView = class { editInitializerVal.innerHTML = 'If this is an initializer, you can input new value for it here:'; this._element.appendChild(editInitializerVal); - var inputInitializerVal = document.createElement("INPUT"); + var inputInitializerVal = document.createElement("textarea"); inputInitializerVal.setAttribute("type", "text"); - inputInitializerVal.setAttribute("size", "42"); - // this._element.appendChild(inputInitializerVal); + inputInitializerVal.rows = 1; + inputInitializerVal.cols = 44; inputInitializerVal.addEventListener('input', (e) => { // console.log(e.target.value) @@ -1050,18 +1059,21 @@ sidebar.ArgumentView = class { // ====== input type ======> const editInitializerType = this._host.document.createElement('div'); editInitializerType.className = 'sidebar-view-item-value-line-border'; - editInitializerType.innerHTML = 'and input its type for it here (see properties->type->? for more info):'; + editInitializerType.innerHTML = 'and input its type for it here ' + '(see properties->type->?' + '' + ' for more info):'; this._element.appendChild(editInitializerType); - var inputInitializerType = document.createElement("INPUT"); + var inputInitializerType = document.createElement("textarea"); inputInitializerType.setAttribute("type", "text"); - inputInitializerType.setAttribute("size", "42"); + inputInitializerType.rows = 1; + inputInitializerType.cols = 44; - var arg_name = this._host._view._graph._addedNode.get(this._modelNodeName).inputs.get(this._parameterName)[this._arg_index] + var arg_name = this._host._view._graph._addedNode.get(this._modelNodeName).inputs.get(this._parameterName)[this._arg_index][0] // [arg.name, arg.is_optional] if (this._host._view._graph._initializerEditInfo.get(arg_name)) { // [type, value] - inputInitializerType.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[0]); - inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[1]); + // inputInitializerType.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[0]); + // inputInitializerVal.setAttribute("value", this._host._view._graph._initializerEditInfo.get(arg_name)[1]); + inputInitializerType.innerHTML = this._host._view._graph._initializerEditInfo.get(arg_name)[0]; + inputInitializerVal.innerHTML = this._host._view._graph._initializerEditInfo.get(arg_name)[1]; } inputInitializerType.addEventListener('input', (e) => { diff --git a/static/view.js b/static/view.js index 665f0b6..142d301 100644 --- a/static/view.js +++ b/static/view.js @@ -588,6 +588,7 @@ view.View = class { viewGraph._initializerEditInfo = this.lastViewGraph._initializerEditInfo; // console.log(viewGraph._renameMap); // console.log(viewGraph._modelNodeName2State) + // console.log(viewGraph._initializerEditInfo) const container = this._getElementById('graph'); this.lastScrollLeft = container.scrollLeft; @@ -868,6 +869,7 @@ view.View = class { } } + // TODO: add filter feature like here: https://www.w3schools.com/howto/howto_js_dropdown.asp UpdateAddNodeDropDown() { // update dropdown supported node lost var addNodeDropdown = this._host.document.getElementById('add-node-dropdown'); @@ -890,19 +892,19 @@ view.View = class { // console.log(node) for (const input of node.inputs) { - var input_list_names = [] + var arg_list_info = [] for (const arg of input._arguments) { - input_list_names.push(arg.name) + arg_list_info.push([arg.name, arg.is_optional]) } - this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, input_list_names) + this.lastViewGraph._addedNode.get(modelNodeName).inputs.set(input.name, arg_list_info) } for (const output of node.outputs) { - var output_list_names = [] + var arg_list_info = [] for (const arg of output._arguments) { - output_list_names.push(arg.name) + arg_list_info.push([arg.name, arg.is_optional]) } - this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, output_list_names) + this.lastViewGraph._addedNode.get(modelNodeName).outputs.set(output.name, arg_list_info) } } @@ -1402,19 +1404,26 @@ view.Graph = class extends grapher.Graph { // changeNodeInputOutput(modelNodeName, parameterName, arg_index, targetValue) { if (this._addedNode.has(modelNodeName)) { // for custom added node if (this._addedNode.get(modelNodeName).inputs.has(parameterName)) { - this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] = targetValue + var arg_name = this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index][0] // [arg.name, arg.is_optional] + // update the corresponding initializer name + if (this._initializerEditInfo.has(arg_name)) { + var init_val = this._initializerEditInfo.get(arg_name); + this._initializerEditInfo.set(targetValue, init_val) + this._initializerEditInfo.delete(arg_name) + } + this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index][0] = targetValue } + // console.log(this._initializerEditInfo) if (this._addedNode.get(modelNodeName).outputs.has(parameterName)) { - this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index] = targetValue + this._addedNode.get(modelNodeName).outputs.get(parameterName)[arg_index][0] = targetValue } - // this.view._updateGraph() // otherwise the changes can not be updated without manully updating graph } // console.log(this._addedNode) else { // for the nodes in the original model var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) - console.log(orig_arg_name) + // console.log(orig_arg_name) if (!this._renameMap.get(modelNodeName)) { this._renameMap.set(modelNodeName, new Map()); @@ -1427,20 +1436,13 @@ view.Graph = class extends grapher.Graph { } changeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { - if (this._addedNode.has(modelNodeName)) { // for custom added node - } - // console.log(this._addedNode) - - else { // for the nodes in the original model - var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) - this._initializerEditInfo.set(orig_arg_name, [type, targetValue]); - } - + var orig_arg_name = this.getOriginalName(param_type, modelNodeName, param_index, arg_index) + this._initializerEditInfo.set(orig_arg_name, [type, targetValue]) this.view._updateGraph() } changeAddedNodeInitializer(modelNodeName, parameterName, param_type, param_index, arg_index, type, targetValue) { - var arg_name = this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index] + var arg_name = this._addedNode.get(modelNodeName).inputs.get(parameterName)[arg_index][0] this._initializerEditInfo.set(arg_name, [type, targetValue]); this.view._updateGraph() } diff --git a/utils/parse_tools.py b/utils/parse_tools.py index 5840f72..523f574 100644 --- a/utils/parse_tools.py +++ b/utils/parse_tools.py @@ -24,6 +24,7 @@ def parse_tensor(tensor_str, tensor_type): return None tensor_str = tensor_str.replace(" ", "") + tensor_str = tensor_str.replace("\n", "") stk = [] for c in tensor_str: # '[' ',' ']' '.' '-' or value if c == ",":