diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py index 1d95bb2..9331098 100644 --- a/onnx_array_api/_helpers.py +++ b/onnx_array_api/_helpers.py @@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any): dt = TensorProto.INT64 elif dtype is float: dt = TensorProto.DOUBLE + elif dtype == np.complex64: + dt = TensorProto.COMPLEX64 + elif dtype == np.complex128: + dt = TensorProto.COMPLEX128 else: raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904 return dt diff --git a/onnx_array_api/annotations.py b/onnx_array_api/annotations.py index 9941f95..c29102c 100644 --- a/onnx_array_api/annotations.py +++ b/onnx_array_api/annotations.py @@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: np.uint64: TensorProto.UINT64, np.bool_: TensorProto.BOOL, np.str_: TensorProto.STRING, + np.complex64: TensorProto.COMPLEX64, + np.complex128: TensorProto.COMPLEX128, } diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 3252405..9b67b4b 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -47,6 +47,8 @@ def _finfo(dtype): continue if isinstance(v, (np.float32, np.float64, np.float16)): d[k] = float(v) + elif isinstance(v, (np.complex128, np.complex64)): + d[k] = complex(v) else: d[k] = v d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) @@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor): module.float16 = DType(TensorProto.FLOAT16) module.float32 = DType(TensorProto.FLOAT) module.float64 = DType(TensorProto.DOUBLE) + module.complex64 = DType(TensorProto.COMPLEX64) + module.complex128 = DType(TensorProto.COMPLEX128) module.int8 = DType(TensorProto.INT8) module.int16 = DType(TensorProto.INT16) module.int32 = DType(TensorProto.INT32) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index abc59a9..d69084a 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -93,6 +93,8 @@ def asarray( v = TEagerTensor(va) elif isinstance(a, float): v = TEagerTensor(np.array(a, dtype=np.float64)) + elif isinstance(a, complex): + v = TEagerTensor(np.array(a, dtype=np.complex128)) elif isinstance(a, bool): v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 293d2cc..558c34a 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -536,7 +536,6 @@ def make_nodes( if isinstance(value, TensorProto): value.name = name self.initializers_dict[name] = value - self.constants_[name] = None self.set_shape(name, builder._known_shapes[init]) self.set_type(name, builder._known_types[init]) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 172bb86..267eda5 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -167,7 +167,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, . f"to the attribute list, v={v}." ) res.append(v.key) - elif isinstance(v, (int, float, bool, DType)): + elif isinstance(v, (int, float, bool, complex, DType)): if iv in self.kwargs_to_input_: res.append(self.kwargs_to_input_[iv]) res.append(type(v)) @@ -204,7 +204,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, . if k in self.kwargs_to_input_: res.append(type(v)) res.append(v) - elif isinstance(v, (int, float, str, type, bool, DType)): + elif isinstance(v, (int, float, str, type, bool, complex, DType)): res.append(k) res.append(type(v)) res.append(v) diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 1daef44..9579455 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -265,6 +265,8 @@ def __float__(self): DType(TensorProto.DOUBLE), DType(TensorProto.FLOAT16), DType(TensorProto.BFLOAT16), + DType(TensorProto.COMPLEX64), + DType(TensorProto.COMPLEX128), }: raise TypeError( f"Conversion to float only works for float scalar, " @@ -272,6 +274,26 @@ def __float__(self): ) return float(self._tensor) + def __complex__(self): + "Implicit conversion to complex." + if self.shape: + raise ValueError( + f"Conversion to bool only works for scalar, not for {self!r}." + ) + if self.dtype not in { + DType(TensorProto.FLOAT), + DType(TensorProto.DOUBLE), + DType(TensorProto.FLOAT16), + DType(TensorProto.BFLOAT16), + DType(TensorProto.COMPLEX64), + DType(TensorProto.COMPLEX128), + }: + raise TypeError( + f"Conversion to float only works for float scalar, " + f"not for dtype={self.dtype}." + ) + return complex(self._tensor) + def __iter__(self): """ The :epkg:`Array API` does not define this function (2022/12). diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 169183c..0e71070 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1171,6 +1171,8 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") + elif isinstance(cst, complex): + Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity") elif isinstance(cst, list): if all(isinstance(t, bool) for t in cst): Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") @@ -1178,6 +1180,8 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif all(isinstance(t, (float, int, bool)) for t in cst): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") + elif all(isinstance(t, (float, int, bool, complex)) for t in cst): + Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity") else: raise ValueError( f"Unable to convert cst (type={type(cst)}), value={cst}." diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index d3f27c6..0b4d30a 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -824,7 +824,10 @@ def str_node(indent, node): rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}") if hasattr(model, "graph"): if model.doc_string: - rows.append(f"doc_string: {model.doc_string}") + if len(model.doc_string) < 55: + rows.append(f"doc_string: {model.doc_string}") + else: + rows.append(f"doc_string: {model.doc_string[:55]}...") main_model = model model = model.graph else: @@ -861,9 +864,16 @@ def str_node(indent, node): else: content = "" line_name_new[init.name] = len(rows) + if init.doc_string: + t = ( + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}{content}" + ) + rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}") + continue rows.append( - "init: name=%r type=%r shape=%r%s" - % (init.name, _get_type(init), _get_shape(init), content) + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}{content}" ) if level == 0: rows.append("----- main graph ----") @@ -1044,7 +1054,10 @@ def _mark_link(rows, lengths, r1, r2, d): for fct in main_model.functions: rows.append(f"----- function name={fct.name} domain={fct.domain}") if fct.doc_string: - rows.append(f"----- doc_string: {fct.doc_string}") + if len(fct.doc_string) < 55: + rows.append(f"----- doc_string: {fct.doc_string}") + else: + rows.append(f"----- doc_string: {fct.doc_string[:55]}...") res = onnx_simple_text_plot( fct, verbose=verbose, @@ -1103,10 +1116,19 @@ def onnx_text_plot_io(model, verbose=False, att_display=None): ) # initializer for init in model.initializer: + + if init.doc_string: + t = ( + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}" + ) + rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}") + continue rows.append( - "init: name=%r type=%r shape=%r" - % (init.name, _get_type(init), _get_shape(init)) + f"init: name={init.name!r} type={_get_type(init)} " + f"shape={_get_shape(init)}" ) + # outputs for out in model.output: rows.append( diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 9c3b6ec..5b77e8b 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -485,6 +485,12 @@ def generate_input(info: ValueInfoProto) -> np.ndarray: return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape) if elem_type == TensorProto.DOUBLE: return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape) + if elem_type == TensorProto.COMPLEX64: + return (value.astype(np.complex64) / p).astype(np.complex64).reshape(new_shape) + if elem_type == TensorProto.COMPLEX128: + return ( + (value.astype(np.complex128) / p).astype(np.complex128).reshape(new_shape) + ) raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}") diff --git a/onnx_array_api/reference/ops/op_constant_of_shape.py b/onnx_array_api/reference/ops/op_constant_of_shape.py index 00c6989..a54bb5a 100644 --- a/onnx_array_api/reference/ops/op_constant_of_shape.py +++ b/onnx_array_api/reference/ops/op_constant_of_shape.py @@ -19,6 +19,8 @@ def _process(value): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) + elif isinstance(cst, complex): + cst = np.complex128(cst) elif cst is None: cst = np.float32(0) if not isinstance( @@ -27,6 +29,8 @@ def _process(value): np.float16, np.float32, np.float64, + np.complex64, + np.complex128, np.int64, np.int32, np.int16,