From 5cbe8e5ecaee7f33d25353fc0827a7895b16d06a Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 7 Nov 2024 12:23:19 +0100 Subject: [PATCH 1/8] Improves onnx_simple_text_plot --- onnx_array_api/plotting/text_plot.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index d3f27c6..f2ccf3a 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -824,7 +824,10 @@ def str_node(indent, node): rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}") if hasattr(model, "graph"): if model.doc_string: - rows.append(f"doc_string: {model.doc_string}") + if len(model.doc_string) < 55: + rows.append(f"doc_string: {model.doc_string}") + else: + rows.append(f"doc_string: {model.doc_string[:55]}...") main_model = model model = model.graph else: @@ -1044,7 +1047,10 @@ def _mark_link(rows, lengths, r1, r2, d): for fct in main_model.functions: rows.append(f"----- function name={fct.name} domain={fct.domain}") if fct.doc_string: - rows.append(f"----- doc_string: {fct.doc_string}") + if len(fct.doc_string) < 55: + rows.append(f"----- doc_string: {fct.doc_string}") + else: + rows.append(f"----- doc_string: {fct.doc_string[:55]}...") res = onnx_simple_text_plot( fct, verbose=verbose, From ed4a4b692c5b195410853f8560aa97c8b5c27e38 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 12:16:37 +0100 Subject: [PATCH 2/8] add doc_string --- onnx_array_api/plotting/text_plot.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index f2ccf3a..662ee1b 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -864,9 +864,15 @@ def str_node(indent, node): else: content = "" line_name_new[init.name] = len(rows) + if init.doc_string: + rows.append( + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}{content} -- {init.doc_string}" + ) + continue rows.append( - "init: name=%r type=%r shape=%r%s" - % (init.name, _get_type(init), _get_shape(init), content) + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}{content}" ) if level == 0: rows.append("----- main graph ----") @@ -1109,10 +1115,18 @@ def onnx_text_plot_io(model, verbose=False, att_display=None): ) # initializer for init in model.initializer: + + if init.doc_string: + rows.append( + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)} -- {init.doc_string}" + ) + continue rows.append( - "init: name=%r type=%r shape=%r" - % (init.name, _get_type(init), _get_shape(init)) + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}" ) + # outputs for out in model.output: rows.append( From df4f45e20cd09f2f99de2424bf0fc908381189a5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 14:46:25 +0100 Subject: [PATCH 3/8] improve display --- onnx_array_api/array_api/_onnx_common.py | 2 ++ onnx_array_api/plotting/text_plot.py | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index abc59a9..d69084a 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -93,6 +93,8 @@ def asarray( v = TEagerTensor(va) elif isinstance(a, float): v = TEagerTensor(np.array(a, dtype=np.float64)) + elif isinstance(a, complex): + v = TEagerTensor(np.array(a, dtype=np.complex128)) elif isinstance(a, bool): v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index 662ee1b..0b4d30a 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -865,10 +865,11 @@ def str_node(indent, node): content = "" line_name_new[init.name] = len(rows) if init.doc_string: - rows.append( + t = ( f"init: name={init.name!r} type={_get_type(init)} " - f"shape={_get_shape(init)}{content} -- {init.doc_string}" + f"shape={_get_shape(init)}{content}" ) + rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}") continue rows.append( f"init: name={init.name!r} type={_get_type(init)} " @@ -1117,10 +1118,11 @@ def onnx_text_plot_io(model, verbose=False, att_display=None): for init in model.initializer: if init.doc_string: - rows.append( + t = ( f"init: name={init.name!r} type={_get_type(init)} " - f"shape={_get_shape(init)} -- {init.doc_string}" + f"shape={_get_shape(init)}" ) + rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}") continue rows.append( f"init: name={init.name!r} type={_get_type(init)} " From 7b658227d6a2ea89d566e1d871daa33c98323e28 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 15:55:20 +0100 Subject: [PATCH 4/8] add complex --- onnx_array_api/_helpers.py | 4 ++++ onnx_array_api/annotations.py | 2 ++ onnx_array_api/array_api/__init__.py | 4 ++++ onnx_array_api/npx/npx_var.py | 4 ++++ onnx_array_api/reference/evaluator_yield.py | 6 ++++++ onnx_array_api/reference/ops/op_constant_of_shape.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py index 1d95bb2..9331098 100644 --- a/onnx_array_api/_helpers.py +++ b/onnx_array_api/_helpers.py @@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any): dt = TensorProto.INT64 elif dtype is float: dt = TensorProto.DOUBLE + elif dtype == np.complex64: + dt = TensorProto.COMPLEX64 + elif dtype == np.complex128: + dt = TensorProto.COMPLEX128 else: raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904 return dt diff --git a/onnx_array_api/annotations.py b/onnx_array_api/annotations.py index 9941f95..c29102c 100644 --- a/onnx_array_api/annotations.py +++ b/onnx_array_api/annotations.py @@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: np.uint64: TensorProto.UINT64, np.bool_: TensorProto.BOOL, np.str_: TensorProto.STRING, + np.complex64: TensorProto.COMPLEX64, + np.complex128: TensorProto.COMPLEX128, } diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 3252405..9b67b4b 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -47,6 +47,8 @@ def _finfo(dtype): continue if isinstance(v, (np.float32, np.float64, np.float16)): d[k] = float(v) + elif isinstance(v, (np.complex128, np.complex64)): + d[k] = complex(v) else: d[k] = v d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) @@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor): module.float16 = DType(TensorProto.FLOAT16) module.float32 = DType(TensorProto.FLOAT) module.float64 = DType(TensorProto.DOUBLE) + module.complex64 = DType(TensorProto.COMPLEX64) + module.complex128 = DType(TensorProto.COMPLEX128) module.int8 = DType(TensorProto.INT8) module.int16 = DType(TensorProto.INT16) module.int32 = DType(TensorProto.INT32) diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 169183c..0e71070 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1171,6 +1171,8 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") + elif isinstance(cst, complex): + Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity") elif isinstance(cst, list): if all(isinstance(t, bool) for t in cst): Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") @@ -1178,6 +1180,8 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif all(isinstance(t, (float, int, bool)) for t in cst): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") + elif all(isinstance(t, (float, int, bool, complex)) for t in cst): + Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity") else: raise ValueError( f"Unable to convert cst (type={type(cst)}), value={cst}." diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 9c3b6ec..5b77e8b 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -485,6 +485,12 @@ def generate_input(info: ValueInfoProto) -> np.ndarray: return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape) if elem_type == TensorProto.DOUBLE: return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape) + if elem_type == TensorProto.COMPLEX64: + return (value.astype(np.complex64) / p).astype(np.complex64).reshape(new_shape) + if elem_type == TensorProto.COMPLEX128: + return ( + (value.astype(np.complex128) / p).astype(np.complex128).reshape(new_shape) + ) raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}") diff --git a/onnx_array_api/reference/ops/op_constant_of_shape.py b/onnx_array_api/reference/ops/op_constant_of_shape.py index 00c6989..a54bb5a 100644 --- a/onnx_array_api/reference/ops/op_constant_of_shape.py +++ b/onnx_array_api/reference/ops/op_constant_of_shape.py @@ -19,6 +19,8 @@ def _process(value): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) + elif isinstance(cst, complex): + cst = np.complex128(cst) elif cst is None: cst = np.float32(0) if not isinstance( @@ -27,6 +29,8 @@ def _process(value): np.float16, np.float32, np.float64, + np.complex64, + np.complex128, np.int64, np.int32, np.int16, From e8c3c420ebe1075d444748ce777c34af7b370255 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 16:19:27 +0100 Subject: [PATCH 5/8] add missing line --- onnx_array_api/npx/npx_numpy_tensors.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 1daef44..9579455 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -265,6 +265,8 @@ def __float__(self): DType(TensorProto.DOUBLE), DType(TensorProto.FLOAT16), DType(TensorProto.BFLOAT16), + DType(TensorProto.COMPLEX64), + DType(TensorProto.COMPLEX128), }: raise TypeError( f"Conversion to float only works for float scalar, " @@ -272,6 +274,26 @@ def __float__(self): ) return float(self._tensor) + def __complex__(self): + "Implicit conversion to complex." + if self.shape: + raise ValueError( + f"Conversion to bool only works for scalar, not for {self!r}." + ) + if self.dtype not in { + DType(TensorProto.FLOAT), + DType(TensorProto.DOUBLE), + DType(TensorProto.FLOAT16), + DType(TensorProto.BFLOAT16), + DType(TensorProto.COMPLEX64), + DType(TensorProto.COMPLEX128), + }: + raise TypeError( + f"Conversion to float only works for float scalar, " + f"not for dtype={self.dtype}." + ) + return complex(self._tensor) + def __iter__(self): """ The :epkg:`Array API` does not define this function (2022/12). From 6527538c0236b4f9a29d3b4e835a1d47e6d2467e Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 16:30:23 +0100 Subject: [PATCH 6/8] complex --- onnx_array_api/npx/npx_jit_eager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 172bb86..e33d941 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -167,7 +167,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, . f"to the attribute list, v={v}." ) res.append(v.key) - elif isinstance(v, (int, float, bool, DType)): + elif isinstance(v, (int, float, bool, complex, DType)): if iv in self.kwargs_to_input_: res.append(self.kwargs_to_input_[iv]) res.append(type(v)) From 685d073704c169ce935a3ceaeffac5d28c9f031f Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 17:38:54 +0100 Subject: [PATCH 7/8] complex --- onnx_array_api/graph_api/graph_builder.py | 2 ++ onnx_array_api/npx/npx_jit_eager.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 293d2cc..c176d4e 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -536,6 +536,8 @@ def make_nodes( if isinstance(value, TensorProto): value.name = name self.initializers_dict[name] = value + if name in builder.initializers_dict_sources: + self.initializers_dict_sources[name] = builder.initializers_dict_sources[name] self.constants_[name] = None self.set_shape(name, builder._known_shapes[init]) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index e33d941..267eda5 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -204,7 +204,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, . if k in self.kwargs_to_input_: res.append(type(v)) res.append(v) - elif isinstance(v, (int, float, str, type, bool, DType)): + elif isinstance(v, (int, float, str, type, bool, complex, DType)): res.append(k) res.append(type(v)) res.append(v) From d4832be2a68983e9a6f47cd4455ad257d8849452 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 12 Nov 2024 18:24:41 +0100 Subject: [PATCH 8/8] fix unwanted code --- onnx_array_api/graph_api/graph_builder.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index c176d4e..558c34a 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -536,9 +536,6 @@ def make_nodes( if isinstance(value, TensorProto): value.name = name self.initializers_dict[name] = value - if name in builder.initializers_dict_sources: - self.initializers_dict_sources[name] = builder.initializers_dict_sources[name] - self.constants_[name] = None self.set_shape(name, builder._known_shapes[init]) self.set_type(name, builder._known_types[init])