You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
import numpy as np
|
|
from typing import cast
|
|
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
|
|
|
|
def parse_value(value_str, value_type):
    """Convert the text of a single number into a Python scalar.

    Args:
        value_str: Text of one numeric value, e.g. ``"3"`` or ``"2e6"``.
        value_type: Element type name. Names starting with ``"int"`` are
            converted with ``int``; names starting with ``"float"`` are
            converted with ``float``.

    Returns:
        The converted ``int`` or ``float``.

    Raises:
        RuntimeError: If ``value_type`` starts with neither ``"int"`` nor
            ``"float"``.
        ValueError: Propagated from ``int``/``float`` when ``value_str`` is
            not a valid literal for the selected converter.
    """
    # Checked in order; the first matching prefix selects the converter.
    for prefix, convert in (("int", int), ("float", float)):
        if value_type.startswith(prefix):
            return convert(value_str)

    raise RuntimeError(
        "type {} is not considered in current version. "
        "You can kindly report an issue for this problem. Thanks!".format(value_type)
    )
|
|
|
|
# parse numpy values from string
|
|
def parse_tensor(tensor_str, tensor_type):
    """Parse a textual scalar or (nested) bracketed list into a numpy array.

    The text is scanned character by character with a stack: number
    characters accumulate until a ``,``, a ``]`` or the end of the input is
    reached, at which point they are converted with ``parse_value``; a ``]``
    folds everything back to the matching ``[`` into one Python list.

    Args:
        tensor_str: Text such as ``"7"``, ``"[1, 2, 3]"`` or
            ``"[[10, 2.3], [1, 2e6]]"``. All whitespace is ignored.
        tensor_type: One of ``"int64"``, ``"int32"``, ``"int8"``,
            ``"float64"`` or ``"float32"``.

    Returns:
        ``np.ndarray`` with the requested dtype.

    Raises:
        RuntimeError: If ``tensor_type`` is not one of the supported names.
        ValueError: If the input holds no value, a ``]`` has no matching
            ``[``, or a numeric token cannot be converted.
    """
    dtype_by_name = {
        "int64": np.int64,
        "int32": np.int32,
        "int8": np.int8,
        "float64": np.float64,
        "float32": np.float32,
    }
    # Reject unsupported types before doing any parsing work; the exception
    # type and message are the ones callers already receive for this case.
    if tensor_type not in dtype_by_name:
        raise RuntimeError(
            "type {} is not considered in current version. "
            "You can kindly report an issue for this problem. Thanks!".format(tensor_type)
        )

    number_chars = frozenset("0123456789+-.eE")
    stack = []

    def pop_number():
        """Pop the run of number characters on top of the stack and convert it."""
        chars = []
        # Already-converted numbers and folded lists are not ``str``, so the
        # isinstance guard keeps them from being consumed again.
        while stack and isinstance(stack[-1], str) and stack[-1] in number_chars:
            chars.append(stack.pop())
        if not chars:
            return None
        chars.reverse()  # popped last-to-first
        return parse_value("".join(chars), tensor_type)

    # Drop every kind of whitespace (spaces, tabs, CR/LF), not only " " and "\n".
    compact = "".join(tensor_str.split())

    for ch in compact:
        if ch == "," or ch == "]":
            number = pop_number()
            if number is not None:
                stack.append(number)
            if ch == "]":
                items = []
                while stack and stack[-1] != "[":
                    items.append(stack.pop())
                if not stack:
                    raise ValueError(
                        "unbalanced ']' in tensor string: {!r}".format(tensor_str)
                    )
                stack.pop()  # discard the matching '['
                items.reverse()
                stack.append(items)
        else:
            stack.append(ch)

    # A bare scalar such as "12" or "-1.5" has no trailing ',' or ']' to
    # trigger conversion, so flush whatever number is still pending.
    number = pop_number()
    if number is not None:
        stack.append(number)

    if not stack:
        raise ValueError("tensor string contains no value: {!r}".format(tensor_str))

    # Wrap with numpy using the requested element type.
    return np.array(stack[0], dtype=dtype_by_name[tensor_type])
|
|
|
|
# map np datatype to onnx datatype
|
|
# https://github.com/onnx/onnx/blob/8669fad0247799f4d8683550eec749974b4f5338/onnx/helper.py#L1177
|
|
def np2onnxdtype(np_dtype):
    """Look up the ONNX tensor element type for a numpy dtype.

    Args:
        np_dtype: A ``np.dtype`` instance, or anything ``np.dtype()`` accepts
            (a scalar type such as ``np.float32`` or a name such as
            ``"float32"``).

    Returns:
        The entry of ``NP_TYPE_TO_TENSOR_TYPE`` for that dtype. ``cast`` is
        only a hint for type checkers and returns the value unchanged.

    Raises:
        KeyError: If the dtype has no entry in ``NP_TYPE_TO_TENSOR_TYPE``.
        TypeError: If ``np_dtype`` cannot be interpreted as a numpy dtype.
    """
    # NOTE(review): per the onnx docs the table is keyed by ``np.dtype``
    # instances. ``np.float32`` compares equal to ``np.dtype("float32")`` but
    # hashes differently, so normalise first to avoid a spurious KeyError.
    return cast(int, NP_TYPE_TO_TENSOR_TYPE[np.dtype(np_dtype)])
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: parse each sample and print the resulting type and
    # array. Other inputs to try: "1" and "[1, 2, 3]".
    samples = (
        ("[[10, 2.3, 3],[1, 2e6, 3]]", "float32"),
        ("[[10, 2, 3],[1, 2, 3]]", "int64"),
    )
    for sample_text, dtype_name in samples:
        parsed = parse_tensor(sample_text, dtype_name)
        print(type(parsed), parsed)
|
|
|
|
|