the feature of `change node attribute` is basically done

1123
ZhangGe6 4 years ago
parent 2be1a4b604
commit 95f18e37b0

@ -8,7 +8,7 @@ import copy
import numpy as np
import onnx
import onnxruntime as rt
from utils import make_node
from utils import make_new_node, make_attr_changed_node
class onnxModifier:
def __init__(self, model_name, model_proto):
@ -117,12 +117,23 @@ class onnxModifier:
if node.output[i] == src_name:
node.output[i] = dst_name
def modify_node_attr(self, node_changed_attr):
# we achieve it by deleting the original node and make a (copied) new node
# print(node_changed_attr)
for node_name in node_changed_attr.keys():
orig_node = self.node_name2module[node_name]
attr_changed_node = make_attr_changed_node(orig_node, node_changed_attr[node_name])
self.graph.node.remove(self.node_name2module[node_name])
self.graph.node.append(attr_changed_node)
# update the node_name2module and initializer_name2module
self.gen_name2module_map()
def add_node(self, nodes_info, node_states):
for node_info in nodes_info.values():
if node_states[node_info['properties']['name']] == "Deleted":
continue
node = make_node(node_info)
node = make_new_node(node_info)
# print(node)
self.graph.node.append(node)
@ -131,17 +142,23 @@ class onnxModifier:
def modify(self, modify_info):
# print(modify_info['node_states'])
# print(modify_info['node_renamed_io'])
print(modify_info['node_changed_attr'])
# print(modify_info['added_node_info'])
self.remove_node_by_node_states(modify_info['node_states'])
self.modify_node_io_name(modify_info['node_renamed_io'])
self.modify_node_attr(modify_info['node_changed_attr'])
self.add_node(modify_info['added_node_info'], modify_info['node_states'])
def check_and_save_model(self, save_dir='./modified_onnx'):
if not os.path.exists(save_dir):
os.mkdir(save_dir)
save_path = os.path.join(save_dir, 'modified_' + self.model_name)
onnx.checker.check_model(self.model_proto)
# adding new node like self.add_node() and self.modify_node_attr() can not guarantee the nodes are topologically sorted
# so `onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted` will be invoked
# I turn off the onnx checker as a workaround.
# onnx.checker.check_model(self.model_proto)
onnx.save(self.model_proto, save_path)
def inference(self, x=None, output_names=None):
@ -241,7 +258,7 @@ if __name__ == "__main__":
# print(inp.name)
onnx_modifier.check_and_save_model()
remove_node_by_node_states()
# remove_node_by_node_states()
def test_modify_node_io_name():
@ -257,7 +274,17 @@ if __name__ == "__main__":
onnx_modifier.inference()
onnx_modifier.check_and_save_model()
# test_add_node()
def test_change_node_attr():
# changed_attr = {'Clip_3': {'max': 5}}
changed_attr = {'Conv_2': {'group': 64}}
onnx_modifier.modify_node_attr(changed_attr)
onnx_modifier.check_and_save_model()
test_change_node_attr()

@ -215,9 +215,9 @@ host.BrowserHost = class {
const downloadButton = this.document.getElementById('download-graph');
downloadButton.addEventListener('click', () => {
console.log(this._view._graph._addedNode)
console.log(this._view._graph._renameMap)
// https://healeycodes.com/talking-between-languages
// console.log(this._view._graph._addedNode)
// console.log(this._view._graph._renameMap)
// // https://healeycodes.com/talking-between-languages
fetch('/download', {
// Declare what type of data we're sending
headers: {
@ -228,6 +228,7 @@ host.BrowserHost = class {
body: JSON.stringify({
'node_states' : this.mapToObjectRec(this._view._graph._modelNodeName2State),
'node_renamed_io' : this.mapToObjectRec(this._view._graph._renameMap),
'node_changed_attr' : this.mapToObjectRec(this._view._graph._changedAttributes),
'added_node_info' : this.mapToObjectRec(this.parseLightNodeInfo2Map(this._view._graph._addedNode))
}

@ -21,6 +21,7 @@ grapher.Graph = class {
this._pathArgumentNames = new Set(); // the name of arguments which occurs in both sides of an edge
this._renameMap = new Map();
this._changedAttributes = new Map();
this._addedNode = new Map();
}

@ -658,7 +658,7 @@ class NodeAttributeView {
this._element.appendChild(this._expander);
}
const value = this._attribute.value;
// console.log(this._attribute.name, value, type)
console.log(this._attribute.name, value, type)
switch (type) {
case 'graph': {
const line = this._host.document.createElement('div');
@ -699,6 +699,7 @@ class NodeAttributeView {
attr_input.setAttribute("value", content ? content : 'undefined');
attr_input.addEventListener('input', (e) => {
// console.log(e.target.value);
// console.log(this.parse_value(e.target.value, type))
this._host._view._graph.changeNodeAttribute(this._modelNodeName, this._attributeName, this.parse_value(e.target.value, type));
// console.log(this._host._view._graph._renameMap);
});

@ -462,6 +462,7 @@ view.View = class {
if (active_graph && this.lastViewGraph) {
this.refreshAddedNode()
this.refreshNodeArguments()
}
return active_graph
@ -582,6 +583,7 @@ view.View = class {
// console.log('node state of lastViewGraph is loaded')
viewGraph._modelNodeName2State = this.lastViewGraph._modelNodeName2State;
viewGraph._renameMap = this.lastViewGraph._renameMap;
viewGraph._changedAttributes = this.lastViewGraph._changedAttributes;
viewGraph._addedNode = this.lastViewGraph._addedNode;
viewGraph._add_nodeKey = this.lastViewGraph._add_nodeKey
// console.log(viewGraph._renameMap);
@ -906,15 +908,6 @@ view.View = class {
// console.log(this._graphs[0])
// console.log(this._graphs[0]._nodes)
// for (const node_name of this.lastViewGraph._renameMap.keys()) {
// var rename_map = this.lastViewGraph._renameMap.get(node_name)
// var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
// console.log(node)
// console.log(rename_map)
// }
for (var node of this._graphs[0]._nodes) {
// this node has some changed arguments
// console.log(node)
@ -978,8 +971,20 @@ view.View = class {
// console.log(this._graphs[0]._context._arguments)
for (const node_name of this.lastViewGraph._changedAttributes.keys()) {
var attr_change_map = this.lastViewGraph._changedAttributes.get(node_name)
var node = this.lastViewGraph._modelNodeName2ModelNode.get(node_name)
// console.log(node)
// console.log(attr_change_map)
// for (const attr of node._attributes) {
for (var i = 0; i < node._attributes.length; ++i) {
if (attr_change_map.get(node._attributes[i].name)) {
node._attributes[i]._value = attr_change_map.get(node._attributes[i].name)
}
}
}
}
@ -1330,6 +1335,16 @@ view.Graph = class extends grapher.Graph {
this._addedNode.get(modelNodeName).attributes.set(attributeName, targetValue)
}
// console.log(this._addedNode)
else { // for the nodes in the original model
if (!this._changedAttributes.get(modelNodeName)) {
this._changedAttributes.set(modelNodeName, new Map());
}
this._changedAttributes.get(modelNodeName).set(attributeName, targetValue)
// console.log(this._changedAttributes)
}
this.view._updateGraph()
}
@ -1348,7 +1363,7 @@ view.Graph = class extends grapher.Graph {
}
// console.log(this._addedNode)
// else { // for the nodes in the original model
// else {
// if (!this._renameMap.get(modelNodeName)) {
// this._renameMap.set(modelNodeName, new Map());
// }
@ -1356,7 +1371,7 @@ view.Graph = class extends grapher.Graph {
// // console.log(this._renameMap)
// }
else {
else { // for the nodes in the original model
if (!this._renameMap.get(modelNodeName)) {
this._renameMap.set(modelNodeName, new Map());
}

@ -1,6 +1,8 @@
import onnx
from onnx import AttributeProto
def make_node(node_info):
def make_new_node(node_info):
name = node_info['properties']['name']
op_type = node_info['properties']['op_type']
# attributes = node_info['attributes']
@ -32,3 +34,51 @@ def make_node(node_info):
# print(node)
return node
def make_attr_changed_node(node, attr_change_info):
# convert the changed attribute value into the type that is consistent with the original attribute
# because AttributeProto is constructed barely based on the input value
# https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L472
def make_type_value(value, AttributeProto_type):
# https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/enum_type_wrapper.py#L60
attr_type = AttributeProto.AttributeType.Name(AttributeProto_type)
if attr_type == "FLOAT":
return float(value)
elif attr_type == "INT":
return int(value)
elif attr_type == "STRING":
return str(value)
elif attr_type == "FLOATS":
return [float(v) for v in value]
elif attr_type == "INTS":
return [int(v) for v in value]
elif attr_type == "STRINGS":
return [str(v) for v in value]
else:
raise RuntimeError("type {} is not considered in current version. \
You can kindly report an issue for this problem. Thanks!".format(attr_type))
new_attr = dict()
for attr in node.attribute:
# print(onnx.helper.get_attribute_value(attr))
if attr.name in attr_change_info.keys():
new_attr[attr.name] = make_type_value(attr_change_info[attr.name], attr.type)
else:
# https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L548
new_attr[attr.name] = onnx.helper.get_attribute_value(attr)
# print(new_attr)
# print(node.input, node.output)
# print(node)
node = onnx.helper.make_node(
op_type=node.op_type,
inputs=node.input,
outputs=node.output,
name=node.name,
**new_attr
)
# print(node)
return node
Loading…
Cancel
Save