diff --git a/.gitignore b/.gitignore
index 6774a18..c51b919 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,4 @@ _doc/_static/viz.js
 _unittests/ut__main/*.png
 _unittests/ut__main/_cache/*
 _unittests/ut__main/*.html
+_unittests/.hypothesis/*
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index 75c0aa4..475fad6 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -8,10 +8,10 @@ API
     array_api
     npx_functions
-    npx_var
     npx_jit
-    npx_annot
     npx_numpy
+    npx_types
+    npx_var
     onnx_tools
     ort
     plotting
diff --git a/_doc/api/npx_annot.rst b/_doc/api/npx_types.rst
similarity index 82%
rename from _doc/api/npx_annot.rst
rename to _doc/api/npx_types.rst
index 43de2d7..dc1a378 100644
--- a/_doc/api/npx_annot.rst
+++ b/_doc/api/npx_types.rst
@@ -1,38 +1,40 @@
-=============
 npx.npx_types
 =============
 
 DType
-=====
++++++
 
 .. autoclass:: onnx_array_api.npx.npx_types.DType
     :members:
 
-Annotations
-===========
-
 ElemType
 ++++++++
 
 .. autoclass:: onnx_array_api.npx.npx_types.ElemType
     :members:
 
-ParType
-+++++++
-
-.. autoclass:: onnx_array_api.npx.npx_types.ParType
-    :members:
-
 OptParType
 ++++++++++
 
 .. autoclass:: onnx_array_api.npx.npx_types.OptParType
     :members:
 
-TensorType
-++++++++++
+OptTensorType
++++++++++++++
 
-.. autoclass:: onnx_array_api.npx.npx_types.TensorType
+.. autoclass:: onnx_array_api.npx.npx_types.OptTensorType
+    :members:
+
+ParType
++++++++
+
+.. autoclass:: onnx_array_api.npx.npx_types.ParType
+    :members:
+
+Scalar
+++++++
+
+.. autoclass:: onnx_array_api.npx.npx_types.Scalar
+    :members:
 
 SequenceType
@@ -41,6 +43,18 @@ SequenceType
 .. autoclass:: onnx_array_api.npx.npx_types.SequenceType
     :members:
 
+ShapeType
++++++++++
+
+.. autoclass:: onnx_array_api.npx.npx_types.ShapeType
+    :members:
+
+TensorType
+++++++++++
+
+.. autoclass:: onnx_array_api.npx.npx_types.TensorType
+    :members:
+
 TupleType
 +++++++++
diff --git a/_doc/api/npx_var.rst b/_doc/api/npx_var.rst
index 8041e5e..1f863fb 100644
--- a/_doc/api/npx_var.rst
+++ b/_doc/api/npx_var.rst
@@ -15,3 +15,16 @@ Cst, Input
 .. autoclass:: onnx_array_api.npx.npx_var.Input
     :members:
+
+ManyIdentity
+++++++++++++
+
+.. autoclass:: onnx_array_api.npx.npx_var.ManyIdentity
+    :members:
+
+Par
++++
+
+.. autoclass:: onnx_array_api.npx.npx_var.Par
+    :members:
+
diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt
index 62de43f..dcb067c 100644
--- a/_unittests/onnx-numpy-skips.txt
+++ b/_unittests/onnx-numpy-skips.txt
@@ -1,6 +1,7 @@
 # API failures
 # see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
-array_api_tests/test_creation_functions.py::test_arange
+array_api_tests/test_creation_functions.py::test_asarray_scalars
+# array_api_tests/test_creation_functions.py::test_arange
 array_api_tests/test_creation_functions.py::test_asarray_arrays
 array_api_tests/test_creation_functions.py::test_empty
 array_api_tests/test_creation_functions.py::test_empty_like
diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh
index 9464ee6..089aa3b 100644
--- a/_unittests/test_array_api.sh
+++ b/_unittests/test_array_api.sh
@@ -1,4 +1,4 @@
 export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
-pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1
+pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_arange || exit 1
 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
-pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
\ No newline at end of file
+pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py
index 4cb7544..23852c7 100644
--- a/_unittests/ut_array_api/test_onnx_numpy.py
+++ b/_unittests/ut_array_api/test_onnx_numpy.py
@@ -1,3 +1,4 @@
+import sys
 import unittest
 import numpy as np
 from onnx_array_api.ext_test_case import ExtTestCase
@@ -19,6 +20,22 @@ def test_zeros(self):
         a = xp.absolute(mat)
         self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
 
+    def test_arange_default(self):
+        a = EagerTensor(np.array([0], dtype=np.int64))
+        b = EagerTensor(np.array([2], dtype=np.int64))
+        mat = xp.arange(a, b)
+        matnp = mat.numpy()
+        self.assertEqual(matnp.shape, (2,))
+        self.assertEqualArray(matnp, np.arange(0, 2).astype(np.int64))
+
+    def test_arange_step(self):
+        a = EagerTensor(np.array([4], dtype=np.int64))
+        s = EagerTensor(np.array([2], dtype=np.int64))
+        mat = xp.arange(a, step=s)
+        matnp = mat.numpy()
+        self.assertEqual(matnp.shape, (2,))
+        self.assertEqualArray(matnp, np.arange(4, step=2).astype(np.int64))
+
     def test_zeros_none(self):
         c = EagerTensor(np.array([4, 5], dtype=np.int64))
         mat = xp.zeros(c)
@@ -52,7 +69,27 @@ def test_full_bool(self):
         self.assertNotEmpty(matnp[0, 0])
         self.assertEqualArray(matnp, np.full((4, 5), False))
 
+    def test_arange_int00a(self):
+        a = EagerTensor(np.array([0], dtype=np.int64))
+        b = EagerTensor(np.array([0], dtype=np.int64))
+        mat = xp.arange(a, b)
+        matnp = mat.numpy()
+        self.assertEqual(matnp.shape, (0,))
+        expected = np.arange(0, 0)
+        if sys.platform == "win32":
+            expected = expected.astype(np.int64)
+        self.assertEqualArray(matnp, expected)
+
+    def test_arange_int00(self):
+        mat = xp.arange(0, 0)
+        matnp = mat.numpy()
+        self.assertEqual(matnp.shape, (0,))
+        expected = np.arange(0, 0)
+        if sys.platform == "win32":
+            expected = expected.astype(np.int64)
+        self.assertEqualArray(matnp, expected)
+
 
 if __name__ == "__main__":
-    TestOnnxNumpy().test_zeros_none()
+    TestOnnxNumpy().test_arange_int00()
unittest.main(verbosity=2) diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py index 17b5863..7a5b33a 100644 --- a/_unittests/ut_npx/test_npx.py +++ b/_unittests/ut_npx/test_npx.py @@ -103,6 +103,7 @@ Int64, OptParType, TensorType, + OptTensorType, ) from onnx_array_api.npx.npx_var import Input, Var @@ -125,35 +126,62 @@ def test_shape_inference(self): self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT) def test_tensor(self): - dt = TensorType["float32"] + dt = TensorType["float32", "F32"] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) self.assertEmpty(dt.shape) - self.assertEqual(dt.type_name(), "TensorType['float32']") + self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']") - dt = TensorType["float32"] + dt = TensorType["float32", "F32"] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) - self.assertEqual(dt.type_name(), "TensorType['float32']") + self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']") - dt = TensorType[np.float32] + dt = TensorType[np.float32, "F32"] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) - self.assertEqual(dt.type_name(), "TensorType['float32']") + self.assertEqual(dt.type_name(), "TensorType['float32', 'F32']") self.assertEmpty(dt.shape) - dt = TensorType[np.str_] + dt = TensorType[np.str_, "TEXT"] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.str_) - self.assertEqual(dt.type_name(), "TensorType[strings]") + self.assertEqual(dt.type_name(), "TensorType[strings, 'TEXT']") + self.assertEmpty(dt.shape) + + self.assertRaise(lambda: TensorType[None], TypeError) + self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError) + + def test_opt_tensor(self): + dt = OptTensorType["float32", "F32"] + self.assertEqual(len(dt.dtypes), 1) + self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) + self.assertEmpty(dt.shape) + self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']") + + dt = OptTensorType["float32", "F32"] + self.assertEqual(len(dt.dtypes), 1) + self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) + self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']") + + dt = OptTensorType[np.float32, "F32"] + self.assertEqual(len(dt.dtypes), 1) + self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) + self.assertEqual(dt.type_name(), "OptTensorType['float32', 'F32']") + self.assertEmpty(dt.shape) + + dt = OptTensorType[np.str_, "TEXT"] + self.assertEqual(len(dt.dtypes), 1) + self.assertEqual(dt.dtypes[0].dtype, ElemType.str_) + self.assertEqual(dt.type_name(), "OptTensorType[strings, 'TEXT']") self.assertEmpty(dt.shape) self.assertRaise(lambda: TensorType[None], TypeError) self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError) def test_superset(self): - t1 = TensorType[ElemType.numerics] - t2 = TensorType[ElemType.float64] + t1 = TensorType[ElemType.numerics, "T"] + t2 = TensorType[ElemType.float64, "F64"] self.assertTrue(t1.issuperset(t2)) t1 = Float32[None] t2 = Float32[None] @@ -167,14 +195,14 @@ def test_superset(self): t1 = Float32["N"] t2 = Float32[5] self.assertTrue(t1.issuperset(t2)) - t1 = TensorType[ElemType.int64] + t1 = TensorType[ElemType.int64, "I"] t2 = Int64[1] self.assertTrue(t1.issuperset(t2)) def test_sig(self): def local1( - x: TensorType[ElemType.floats], - ) -> TensorType[ElemType.floats]: + x: TensorType[ElemType.floats, "T"], + ) -> TensorType[ElemType.floats, "T"]: return x 
def local2( @@ -2536,13 +2564,17 @@ def test_numpy_all_empty_axis_1(self): got = ref.run(None, {"A": data}) self.assertEqualArray(y, got[0]) - @unittest.skipIf(True, reason="Fails to follow Array API") - def test_get_item(self): + def test_get_item_b(self): a = EagerNumpyTensor(np.array([True], dtype=np.bool_)) i = a[0] self.assertEqualArray(i.numpy(), a.numpy()[0]) + def test_get_item_i8(self): + a = EagerNumpyTensor(np.array([5, 6], dtype=np.int8)) + i = a[0] + self.assertEqualArray(i.numpy(), a.numpy()[0]) + if __name__ == "__main__": - # TestNpx().test_get_item() + TestNpx().test_filter() unittest.main(verbosity=2) diff --git a/_unittests/ut__main/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py similarity index 100% rename from _unittests/ut__main/test_documentation_examples.py rename to _unittests/ut_xrun_doc/test_documentation_examples.py diff --git a/_unittests/ut__main/test_profiling.py b/_unittests/ut_xrun_doc/test_profiling.py similarity index 100% rename from _unittests/ut__main/test_profiling.py rename to _unittests/ut_xrun_doc/test_profiling.py diff --git a/_unittests/win_test_array_api.bat b/_unittests/win_test_array_api.bat new file mode 100644 index 0000000..1ec2833 --- /dev/null +++ b/_unittests/win_test_array_api.bat @@ -0,0 +1,4 @@ +@echo off +set ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy +python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py::test_arange || exit 1 +python -m pytest ../../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 425418f..9aab6f8 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -3,6 +3,7 @@ """ from typing import Any, Optional import numpy as np +from onnx import TensorProto from ..npx.npx_functions import ( all, abs, @@ -15,11 +16,20 @@ reshape, take, ) +from ..npx.npx_functions import arange as generic_arange from ..npx.npx_functions import full as generic_full from ..npx.npx_functions import ones as generic_ones from ..npx.npx_functions import zeros as generic_zeros from ..npx.npx_numpy_tensors import EagerNumpyTensor -from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar +from ..npx.npx_types import ( + DType, + ElemType, + TensorType, + OptParType, + OptTensorType, + ParType, + Scalar, +) from ._onnx_common import template_asarray from . 
import _finalize_array_api @@ -27,6 +37,7 @@ "abs", "absolute", "all", + "arange", "asarray", "astype", "empty", @@ -57,6 +68,44 @@ def asarray( ) +def arange( + start_or_stop: TensorType[ElemType.int64, "I", (1,)], + stop_or_step: OptTensorType[ElemType.int64, "I", (1,)] = None, + step: OptTensorType[ElemType.int64, "I", (1,)] = None, + dtype: OptParType[DType] = None, +) -> TensorType[ElemType.numerics, "T"]: + use_float = any( + map(lambda x: isinstance(x, float), [start_or_stop, stop_or_step, step]) + ) + if isinstance(start_or_stop, int): + start_or_stop = EagerNumpyTensor( + np.array([start_or_stop], dtype=np.float64 if use_float else np.int64) + ) + elif isinstance(start_or_stop, float): + start_or_stop = EagerNumpyTensor(np.array([start_or_stop], dtype=np.float64)) + assert use_float + + if isinstance(stop_or_step, int): + stop_or_step = EagerNumpyTensor( + np.array([stop_or_step], dtype=np.float64 if use_float else np.int64) + ) + elif isinstance(stop_or_step, float): + stop_or_step = EagerNumpyTensor(np.array([stop_or_step], dtype=np.float64)) + assert use_float + + if isinstance(step, int): + step = EagerNumpyTensor( + np.array([step], dtype=np.float64 if use_float else np.int64) + ) + elif isinstance(step, float): + step = EagerNumpyTensor(np.array([step], dtype=np.float64)) + assert use_float + + if dtype is None and use_float: + dtype = DType(TensorProto.DOUBLE) + return generic_arange(start_or_stop, stop_or_step, step, dtype=dtype) + + def ones( shape: TensorType[ElemType.int64, "I", (None,)], dtype: OptParType[DType] = None, diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 98e37f4..27147c4 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Tuple, Union import array_api_compat.numpy as np_array_api import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto @@ -11,11 +11,12 @@ DType, ElemType, OptParType, + OptTensorType, ParType, + Scalar, SequenceType, TensorType, TupleType, - Scalar, ) from .npx_var import Var @@ -45,7 +46,7 @@ def absolute( @npxapi_inline def all( x: TensorType[ElemType.bool_, "T"], - axis: Optional[TensorType[ElemType.int64, "I"]] = None, + axis: OptTensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0, ) -> TensorType[ElemType.bool_, "T"]: """ @@ -76,20 +77,6 @@ def all( return var(red, cst(1), op="Equal") -@npxapi_inline -def arccos(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: - "See :func:`numpy.arccos`." - return var(x, op="Acos") - - -@npxapi_inline -def arccosh( - x: TensorType[ElemType.numerics, "T"] -) -> TensorType[ElemType.numerics, "T"]: - "See :func:`numpy.arccosh`." 
- return var(x, op="Acosh") - - @npxapi_inline def amax( x: TensorType[ElemType.numerics, "T"], @@ -116,10 +103,40 @@ def amin( @npxapi_inline def arange( - start_or_stop: TensorType[ElemType.int64, "I", (1,)], - stop_or_step: Optional[TensorType[ElemType.int64, "I", (1,)]] = None, - step: Optional[TensorType[ElemType.int64, "I", (1,)]] = None, - dtype=None, + start_or_stop: TensorType[ + { + ElemType.int16, + ElemType.int32, + ElemType.int64, + ElemType.float32, + ElemType.float64, + }, + "I", + (1,), + ], + stop_or_step: OptTensorType[ + { + ElemType.int16, + ElemType.int32, + ElemType.int64, + ElemType.float32, + ElemType.float64, + }, + "I", + (1,), + ] = None, + step: OptTensorType[ + { + ElemType.int16, + ElemType.int32, + ElemType.int64, + ElemType.float32, + ElemType.float64, + }, + "I", + (1,), + ] = None, + dtype: OptParType[DType] = None, ) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.arccos`." if stop_or_step is None: @@ -140,6 +157,20 @@ def arange( return v +@npxapi_inline +def arccos(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: + "See :func:`numpy.arccos`." + return var(x, op="Acos") + + +@npxapi_inline +def arccosh( + x: TensorType[ElemType.numerics, "T"] +) -> TensorType[ElemType.numerics, "T"]: + "See :func:`numpy.arccosh`." + return var(x, op="Acosh") + + @npxapi_inline def argmax( x: TensorType[ElemType.numerics, "T"], @@ -298,7 +329,7 @@ def cosh(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, @npxapi_inline def cumsum( x: TensorType[ElemType.numerics, "T"], - axis: Optional[TensorType[ElemType.int64, "I"]] = None, + axis: OptTensorType[ElemType.int64, "I"] = None, ) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.cumsum`." if axis is None: @@ -522,8 +553,8 @@ def ones( def pad( x: TensorType[ElemType.numerics, "T"], pads: TensorType[ElemType.int64, "I"], - constant_value: Optional[TensorType[ElemType.numerics, "T"]] = None, - axes: Optional[TensorType[ElemType.int64, "I"]] = None, + constant_value: OptTensorType[ElemType.numerics, "T"] = None, + axes: OptTensorType[ElemType.int64, "I"] = None, mode: ParType[str] = "constant", ): """ @@ -618,7 +649,7 @@ def sqrt(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, @npxapi_inline def squeeze( x: TensorType[ElemType.numerics, "T"], - axis: Optional[TensorType[ElemType.int64, "I"]] = None, + axis: OptTensorType[ElemType.int64, "I"] = None, ) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.squeeze`." if axis is None: diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index ff02843..396cf39 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -41,6 +41,7 @@ DType, ElemType, OptParType, + OptTensorType, ParType, SequenceType, TensorType, @@ -270,7 +271,7 @@ def make_node( self.nodes_.append(node) def _io( - self, index: int, name: str, tensor_type: Optional[type], is_input: bool + self, index: int, name: str, tensor_type: type, is_input: bool ) -> ValueInfoProto: """ Converts an input or output into :class:`onnx.ValueInfoProto`. @@ -284,7 +285,9 @@ def _io( """ if self.as_function: return _FunctionIO(name) - if tensor_type is not None and not issubclass(tensor_type, TensorType): + if tensor_type is not None and not issubclass( + tensor_type, (TensorType, OptTensorType) + ): raise TypeError( f"Unexpected type {tensor_type.type_name()} for tensor_type. 
" f"This may happen if you specialised the function based on " @@ -329,18 +332,44 @@ def _io( if is_input: raise RuntimeError( f"tensor_type cannot be None for name={name!r} and " - f"input or output {index}." + f"input or output {index!r}." ) - tensor_type = TensorType["undefined"] - if len(tensor_type.dtypes) != 1: + tensor_type = TensorType["undefined", "xxu"] + + dtype_code = None + if len(tensor_type.dtypes) == 1: + dtype_code = tensor_type.dtypes[0].dtype + else: + # Case when the constraints is too broad. + # We use the input type if available. + if index < len(self.inputs_): + use = self.inputs_[index] + else: + use = None + c_name = tensor_type.name + for i in range(len(self.inputs_)): + name = self.inputs_[i].name + if ( + name in self.constraints + and self.constraints[name].name == c_name + ): + use = self.inputs_[i] + if use is not None: + dtype_code = DType(use.type.tensor_type.elem_type) + + if dtype_code is None: raise RuntimeError( f"tensor_type is not specific enough ({str(tensor_type)} " - f"or its full representation {tensor_type!r})." + f"or its full representation {tensor_type!r}, " + f"is_input={is_input}, index={index}/{len(self.inputs_)}, " + f"self.constraints={self.constraints!r}, " + f"self.inputs_={self.inputs_})." ) + if tensor_type.shape is None: type_proto = TypeProto() tensor_type_proto = type_proto.tensor_type - tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype.code + tensor_type_proto.elem_type = dtype_code.code value_info_proto = ValueInfoProto() value_info_proto.name = name # tensor_type_proto.shape.dim.extend([]) @@ -351,7 +380,7 @@ def _io( # with fixed rank. This can be changed here and in methods # `make_key`. shape = [None for _ in tensor_type.shape] - info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype.code, shape) + info = make_tensor_value_info(name, dtype_code.code, shape) # check_value_info fails if the shape is left undefined check_value_info(info, self.check_context) return info @@ -494,7 +523,15 @@ def _function_to_onnx(self, fct: Callable, n_inputs: int, n_outputs: int): anno = par.annotation if not issubclass( anno, - (ElemType, OptParType, ParType, SequenceType, TensorType, TupleType), + ( + ElemType, + OptParType, + ParType, + SequenceType, + TensorType, + OptTensorType, + TupleType, + ), ): raise TypeError( f"Annotation must of a known not {type(anno)} for " diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index c222f01..58ffff6 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -5,7 +5,7 @@ import numpy as np from .npx_tensors import EagerTensor, JitTensor -from .npx_types import DType, TensorType +from .npx_types import DType, OptTensorType, TensorType from .npx_var import Cst, Input, Var logger = getLogger("onnx-array-api") @@ -47,6 +47,7 @@ def __init__( # onnx to remember an input in fact a mandatory parameter. 
self.n_inputs_ = 0 self.input_to_kwargs_ = None + self.kwargs_to_input_ = None self.method_name_ = None def info(self, prefix: Optional[str] = None, method_name: Optional[str] = None): @@ -57,13 +58,14 @@ def info(self, prefix: Optional[str] = None, method_name: Optional[str] = None): logger.info("") return logger.info( - "%s [%s.%s] nx=%d ni=%d kw=%d f=%s.%s cl=%s me=%s", + "%s [%s.%s] nx=%d ni=%d ikw=%d kwi=%d f=%s.%s cl=%s me=%s", prefix, self.__class__.__name__, method_name[:6], len(self.onxs), self.n_inputs_, 0 if self.input_to_kwargs_ is None else 1, + 0 if self.kwargs_to_input_ is None else 1, self.f.__module__, self.f.__name__, self.tensor_class.__name__, @@ -78,7 +80,8 @@ def status(self, me: str) -> str: f"[{self.__class__.__name__}.{me[:6]}]" f"nx={len(self.onxs)} " f"ni={self.n_inputs_} " - f"kw={0 if self.input_to_kwargs_ is None else 1} " + f"ikw={0 if self.input_to_kwargs_ is None else 1} " + f"kwi={0 if self.kwargs_to_input_ is None else 1} " f"f={self.f.__module__}.{self.f.__name__} " f"cl={self.tensor_class.__name__} " f"me={self.method_name_ or ''}" @@ -120,21 +123,36 @@ def get_onnx(self, key: Optional[int] = None): ) return self.onxs[key] - @staticmethod - def make_key(*values, **kwargs): + def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, ...]: """ Builds a key based on the input types and parameters. Every set of inputs or parameters producing the same key (or signature) must use the same compiled ONNX. + + :param values: values given to the function + :param kwargs: parameters + :return: tuple of mutable keys """ res = [] for iv, v in enumerate(values): if isinstance(v, (Var, EagerTensor, JitTensor)): + if iv in self.kwargs_to_input_: + raise RuntimeError( + f"Input {iv} should be a constant to be moved " + f"to the attribute list, v={v}." + ) res.append(v.key) elif isinstance(v, (int, float, bool, DType)): + if iv in self.kwargs_to_input_: + res.append(self.kwargs_to_input_[iv]) res.append(type(v)) res.append(v) elif isinstance(v, slice): + if iv in self.kwargs_to_input_: + raise NotImplementedError( + f"Input {iv} should be a constant to be moved " + f"to the attribute list, v={v}." + ) res.append(("slice", v.start, v.stop, v.step)) elif isinstance(v, type): res.append(("type", v.__name__)) @@ -148,13 +166,20 @@ def make_key(*values, **kwargs): else: raise TypeError(f"Input {iv} cannot have such tuple: {v}.") res.append(tuple(subkey)) + elif v is None: + if iv in self.kwargs_to_input_: + res.append(self.kwargs_to_input_[iv]) + res.append(None) else: raise TypeError( f"Unable to build a key, input {iv} has type {type(v)}." 
) if kwargs: for k, v in sorted(kwargs.items()): - if isinstance(v, (int, float, str, type, bool, DType)): + if k in self.kwargs_to_input_: + res.append(type(v)) + res.append(v) + elif isinstance(v, (int, float, str, type, bool, DType)): res.append(k) res.append(type(v)) res.append(v) @@ -170,9 +195,9 @@ def make_key(*values, **kwargs): else: newv.append(t) res.append(tuple(newv)) - elif v is None and k in {"dtype"}: - res.append(k) - res.append(v) + elif v is None: + # optional parameter or inputs + pass else: raise TypeError( f"Type {type(v)} is not yet supported, " @@ -193,6 +218,7 @@ def to_jit(self, *values, **kwargs): annotations = self.f.__annotations__ if len(annotations) > 0: input_to_kwargs = {} + kwargs_to_input = {} names = list(annotations.keys()) annot_values = list(annotations.values()) constraints = {} @@ -200,28 +226,47 @@ def to_jit(self, *values, **kwargs): for i, (v, iname) in enumerate(zip(values, names)): if i < len(annot_values) and not isinstance(annot_values[i], type): raise TypeError( - f"annotation {i} is not a type but is {annot_values[i]!r}." + f"annotation {i} is not a type but is " + f"{type(annot_values[i])!r}, " + f"annot_values[i]={annot_values[i]!r}, " f"for function {self.f} " f"from module {self.f.__module__!r}." ) if isinstance(v, (EagerTensor, JitTensor)) and ( i >= len(annot_values) or issubclass(annot_values[i], TensorType) ): - constraints[iname] = v.tensor_type_dims + constraints[iname] = v.tensor_type_dims(annot_values[i].name) + elif ( + v is None + and i < len(annot_values) + and issubclass(annot_values[i], OptTensorType) + ): + constraints[iname] = annot_values[i] + kwargs_to_input[iname] = i, annot_values[i] else: new_kwargs[iname] = v input_to_kwargs[i] = iname if self.input_to_kwargs_ is None: - self.n_inputs_ = len(values) - len(input_to_kwargs) + self.n_inputs_ = ( + len(values) - len(input_to_kwargs) + len(kwargs_to_input) + ) self.input_to_kwargs_ = input_to_kwargs - elif self.input_to_kwargs_ != input_to_kwargs: + self.kwargs_to_input_ = kwargs_to_input + elif ( + self.input_to_kwargs_ != input_to_kwargs + or self.input_to_kwargs_ != input_to_kwargs + ): raise RuntimeError( f"Unexpected input and argument. Previous call produced " - f"self.input_to_kwargs_={self.input_to_kwargs_} and " - f"input_to_kwargs={input_to_kwargs} for function {self.f} " - f"from module {self.f.__module__!r}." + f"self.input_to_kwargs_={self.input_to_kwargs_}, " + f"self.kwargs_to_input_={self.kwargs_to_input_}, " + f"self.n_inputs_={self.n_inputs_} and " + f"input_to_kwargs={input_to_kwargs}, " + f"kwargs_to_input={kwargs_to_input} for function {self.f} " + f"from module {self.f.__module__!r}, " + f"len(values)={len(values)}, kwargs={kwargs!r}." 
) - elif self.input_to_kwargs_: + elif self.input_to_kwargs_ or self.kwargs_to_input_: constraints = {} new_kwargs = {} for i, (v, iname) in enumerate(zip(values, names)): @@ -233,25 +278,28 @@ def to_jit(self, *values, **kwargs): ) and i not in self.input_to_kwargs_ ): - constraints[iname] = v.tensor_type_dims + constraints[iname] = v.tensor_type_dims(iname) else: new_kwargs[iname] = v else: names = [f"x{i}" for i in range(len(values))] new_kwargs = {} constraints = { - iname: v.tensor_type_dims + iname: v.tensor_type_dims(iname) for i, (v, iname) in enumerate(zip(values, names)) if isinstance(v, (EagerTensor, JitTensor)) } self.n_inputs_ = len(values) self.input_to_kwargs_ = {} + self.kwargs_to_input_ = {} if self.output_types is not None: constraints.update(self.output_types) inputs = [ - Input(iname) for iname, v in zip(names, values) if iname in constraints + Input(iname, annotation=constraints[iname]) + for iname, v in zip(names, values) + if iname in constraints ] names = [i.name for i in inputs] if len(new_kwargs) > 0: @@ -262,8 +310,14 @@ def to_jit(self, *values, **kwargs): else: kwargs = kwargs.copy() kwargs.update(new_kwargs) - - var = self.f(*inputs, **kwargs) + try: + var = self.f(*inputs, **kwargs) + except TypeError as e: + raise TypeError( + f"Unexpected error, inputs={inputs}, kwargs={kwargs}, " + f"self.input_to_kwargs_={self.input_to_kwargs_}, " + f"self.kwargs_to_input_={self.kwargs_to_input_}." + ) from e onx = var.to_onnx( constraints=constraints, @@ -361,9 +415,14 @@ def jit_call(self, *values, **kwargs): # No jitting was ever called. try: onx, fct = self.to_jit(*values, **kwargs) + except TypeError as e: + raise TypeError( + f"ERROR with self.f={self.f}, " + f"values={values!r}, kwargs={kwargs!r}" + ) from e except Exception as e: raise RuntimeError( - f"ERROR with self.f={self.f}, " + f"Undefined ERROR with self.f={self.f}, " f"values={values!r}, kwargs={kwargs!r}" ) from e if self.input_to_kwargs_ is None: @@ -371,6 +430,11 @@ def jit_call(self, *values, **kwargs): f"Attribute 'input_to_kwargs_' should be set for " f"function {self.f} form module {self.f.__module__!r}." ) + if self.kwargs_to_input_ is None: + raise RuntimeError( + f"Attribute 'kwargs_to_input_' should be set for " + f"function {self.f} form module {self.f.__module__!r}." + ) else: onx, fct = None, None diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index f89ed9f..5a41cc8 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -42,7 +42,7 @@ def run(self, *inputs: List["NumpyTensor"]) -> List["NumpyTensor"]: ) feeds = {} for name, inp in zip(self.input_names, inputs): - feeds[name] = inp.value + feeds[name] = None if inp is None else inp.value res = self.ref.run(None, feeds) return list(map(self.tensor_class, res)) @@ -122,16 +122,18 @@ def shape(self) -> Tuple[int, ...]: "Returns the shape of the tensor." return self._tensor.shape - @property - def tensor_type_dims(self) -> TensorType: + def tensor_type_dims(self, name: str) -> TensorType: """ Returns the tensor type of this tensor. This property is used to define a key used to cache a jitted function. Same keys keys means same ONNX graph. Different keys usually means same ONNX graph but different input shapes. 
+ + :param name: name of the constraint """ - return TensorType[self.dtype, self.dims] + dt = self.dtype + return TensorType[dt, self.dims, name] @classmethod def create_function(cls: Any, input_names: List[str], onx: ModelProto) -> Callable: diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index 0f7f6dc..f9029f8 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -55,6 +55,8 @@ def np_dtype(self) -> "np.dtype": def __eq__(self, dt: "DType") -> bool: "Compares two types." + if dt is None: + return False if dt.__class__ is DType: return self.code_ == dt.code_ if isinstance(dt, (int, bool, str)): @@ -68,6 +70,8 @@ def __eq__(self, dt: "DType") -> bool: if dt in ElemType.numpy_map: dti = ElemType.numpy_map[dt] return self.code_ == dti.code_ + if isinstance(dt, type) and issubclass(dt, ElemType): + return self.code_ == dt.dtype.code_ try: dti = np_dtype_to_tensor_dtype(dt) except KeyError: @@ -93,12 +97,12 @@ def type_name(cls) -> str: class _DType2(DType): - "Wraps an into a different type." + "Wraps a type into a different type." pass class _DTypes(DType): - "Wraps an into a different type." + "Wraps a type into a different type." pass @@ -109,23 +113,23 @@ class ElemTypeCstInner(WrapperType): __slots__ = [] - undefined = DType(0) - bool_ = DType(9) - int8 = DType(3) - int16 = DType(5) - int32 = DType(6) - int64 = DType(7) - uint8 = DType(2) - uint16 = DType(4) - uint32 = DType(12) - uint64 = DType(13) - float16 = DType(10) - float32 = DType(1) - float64 = DType(11) - bfloat16 = DType(16) - complex64 = DType(14) - complex128 = DType(15) - str_ = DType(8) + undefined = DType(TensorProto.UNDEFINED) # 0 + bool_ = DType(TensorProto.BOOL) # 9 + int8 = DType(TensorProto.INT8) # 3 + int16 = DType(TensorProto.INT16) # 5 + int32 = DType(TensorProto.INT32) # 6 + int64 = DType(TensorProto.INT64) # 7 + uint8 = DType(TensorProto.UINT8) # 2 + uint16 = DType(TensorProto.UINT16) # 4 + uint32 = DType(TensorProto.UINT32) # 12 + uint64 = DType(TensorProto.UINT64) # 13 + float16 = DType(TensorProto.FLOAT16) # 10 + float32 = DType(TensorProto.FLOAT) # 1 + float64 = DType(TensorProto.DOUBLE) # 11 + bfloat16 = DType(TensorProto.BFLOAT16) # 16 + complex64 = DType(TensorProto.COMPLEX64) # 14 + complex128 = DType(TensorProto.COMPLEX128) # 15 + str_ = DType(TensorProto.STRING) # 8 class ElemTypeCstSet(ElemTypeCstInner): @@ -250,6 +254,9 @@ class ElemType(ElemTypeCst): @classmethod def __class_getitem__(cls, dtype: Union[str, DType]): + """ + Returns a subclass of this one with attribute `dtype`. + """ if isinstance(dtype, str): dtype = ElemType.names_int[dtype] elif dtype in ElemType.numpy_map: @@ -422,8 +429,14 @@ class TensorType(WrapperType): :param name: name of the type """ + main_name = "TensorType" + @classmethod def __class_getitem__(cls, *args): + """ + Returns a subclass of this one with two attributes `dtypes` + and `shape`. + """ if isinstance(args, tuple) and len(args) == 1 and isinstance(args[0], tuple): args = args[0] name = None @@ -500,8 +513,20 @@ def __class_getitem__(cls, *args): ) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") + if newt.name is None: + raise RuntimeError( + f"A constraint needs a name but none is given: args={args}." + ) return newt + @classmethod + def supports_dtype(cls, dtype: DType) -> bool: + """ + Determines if the element type `dtype` + is within `dtypes`. + """ + return dtype in cls.dtypes + @classmethod def type_name(cls) -> str: "Returns its full name." 
@@ -515,14 +540,14 @@ def type_name(cls) -> str: set_name = repr(st) if cls.shape: if cls.name: - newt = f"TensorType[{set_name}, {cls.shape!r}, {cls.name!r}]" + newt = f"{cls.main_name}[{set_name}, {cls.shape!r}, {cls.name!r}]" else: - newt = f"TensorType[{set_name}, {cls.shape!r}]" + newt = f"{cls.main_name}[{set_name}, {cls.shape!r}]" elif cls.name: - newt = f"TensorType[{set_name}, {cls.name!r}]" + newt = f"{cls.main_name}[{set_name}, {cls.name!r}]" else: - newt = f"TensorType[{set_name}]" - if "<" in newt or "{" in newt: + newt = f"{cls.main_name}[{set_name}]" + if "<" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt @@ -560,6 +585,16 @@ def issuperset(cls, tensor_type: type) -> bool: return True +class OptTensorType(TensorType): + """ + Defines an optional tensor type. + + :param dtype: element type + """ + + main_name = "OptTensorType" + + class SequenceType(WrapperType): """ Defines a sequence of tensors. @@ -661,7 +696,7 @@ def _make_type(name: str, elem_type: int): def class_getitem(cls, shape: Union[int, ShapeType]) -> TensorType: if isinstance(shape, int): shape = (shape,) - return TensorType[elem_type, shape] + return TensorType[elem_type, shape, f"mtx{elem_type}"] new_type = type(name, tuple(), {}) new_type.__class_getitem__ = classmethod(class_getitem) diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index a4802e3..90022c6 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1,10 +1,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -from onnx import FunctionProto, ModelProto, NodeProto +from onnx import FunctionProto, ModelProto, NodeProto, TensorProto from .._helpers import np_dtype_to_tensor_dtype from .npx_array_api import BaseArrayApi, ArrayApiError from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN -from .npx_types import DType, ElemType, OptParType, ParType, TensorType, TupleType +from .npx_types import ( + DType, + ElemType, + OptParType, + ParType, + TensorType, + OptTensorType, + TupleType, +) class Par: @@ -280,18 +288,19 @@ def __getitem__(self, *args): def __init__( self, *inputs: List[Any], - op: Union[ - Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto + op: Optional[ + Union[Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto] ] = None, - dtype: Union[type, DType] = None, + dtype: Optional[Union[type, DType]] = None, inline: bool = False, - n_var_outputs: Optional[int] = 1, + n_var_outputs: int = 1, input_indices: Optional[List[int]] = None, **kwargs, ): self.inputs = list(inputs) self.n_var_outputs = n_var_outputs self.inline = inline + self._annotation = None if op is None: self.onnx_op = None # a constant elif isinstance(op, tuple): @@ -354,6 +363,16 @@ def __init__( self.set = Var._setter(self) self.current_var_ = None + @property + def annotation(self): + """Returns a type if known for the Var itself.""" + if self._annotation is None: + if "dtype" in self.onnx_op_kwargs: + dtype = self.onnx_op_kwargs["dtype"] + if isinstance(dtype, DType): + return TensorType[dtype] + return self._annotation + @property def self_var(self): """ @@ -852,7 +871,7 @@ def reshape(self, shape: "Var") -> "Var": def reduce_function( self, reduce_op, - axis: TensorType[ElemType.int64, "I"] = None, + axis: OptTensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0, ) -> "Var": "See :func:`numpy.sum` or any other reduce function." 
@@ -962,11 +981,37 @@ def __getitem__(self, index: Any) -> "Var": if isinstance(index, Var): # scenario 2 - # TODO: fix this when index is an integer - new_shape = cst(np.array([-1], dtype=np.int64)) - new_self = self.reshape(new_shape) - new_index = index.reshape(new_shape) - return var(new_self, new_index, op="Compress") + # we rely on the annotation if it exists + if index.annotation is None: + dtype_bool = True + elif issubclass(index.annotation, TensorType): + if index.annotation.supports_dtype( + DType(TensorProto.INT64) + ) or index.annotation.supports_dtype(DType(TensorProto.INT32)): + dtype_bool = False + elif index.annotation.supports_dtype(DType(TensorProto.BOOL)): + dtype_bool = True + else: + raise TypeError( + f"Unexpected dtype for annotation={index.annotation!r} " + f"for index={index!r}." + ) + else: + raise TypeError( + f"Unexpected annotation={index.annotation!r} " + f"for index={index!r}." + ) + + if dtype_bool: + # TODO: fix this when index is an integer and the annotation unknown + # it needs to support subgraph and tests + new_shape = cst(np.array([-1], dtype=np.int64)) + new_self = self.reshape(new_shape) + new_index = index.reshape(new_shape) + return var(new_self, new_index, op="Compress") + + # dtype is int + return var(self, index, axis=0, op="Gather") if isinstance(index, int): # Use Gather instead. @@ -1089,15 +1134,26 @@ class Input(Var): Defines an input, a placeholder. :param name: input name or None if undefined + :param annotation: annotation if any is available """ - def __init__(self, name=None): + def __init__(self, name: str = None, annotation: Optional[type] = None): Var.__init__(self) self.name = name self._prefix = name or "I" + self._annotation = annotation def __repr__(self): - return f"{self.__class__.__name__}({self.name!r})" + if self.annotation is None: + return f"{self.__class__.__name__}({self.name!r})" + return ( + f"{self.__class__.__name__}({self.name!r}, " + f"{self._annotation.__name__!r})" + ) + + @property + def annotation(self): + return self._annotation class Cst(Var): diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index db9d4d5..f4f447d 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -182,16 +182,18 @@ def dims(self): return tuple(self._tensor.shape()) return (None, *tuple(self.shape[1:])) - @property - def tensor_type_dims(self) -> TensorType: + def tensor_type_dims(self, name: str) -> TensorType: """ Returns the tensor type of this tensor. This property is used to define a key used to cache a jitted function. Same keys keys means same ONNX graph. Different keys usually means same ONNX graph but different input shapes. + + :param name: name of the constraint """ - return TensorType[self.dtype, self.dims] + dt = self.dtype + return TensorType[dt, self.dims, name] @classmethod def create_function(cls: Any, input_names: List[str], onx: ModelProto) -> Callable:
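A minimal usage sketch, not part of the patch, of the `arange` wrapper and the annotation-aware `__getitem__` introduced above. Module paths and behaviour are taken from the new unit tests (`test_arange_default`, `test_arange_int00`, `test_get_item_i8`), assuming the package from this diff is installed.

import numpy as np
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor

# bounds supplied as 1-D int64 tensors, as in test_arange_default
start = EagerNumpyTensor(np.array([0], dtype=np.int64))
stop = EagerNumpyTensor(np.array([2], dtype=np.int64))
print(xp.arange(start, stop).numpy())   # [0 1]

# plain Python ints are promoted to int64 tensors by the wrapper itself
print(xp.arange(0, 0).numpy().shape)    # (0,)

# integer indexing now lowers to Gather instead of Compress (test_get_item_i8)
t = EagerNumpyTensor(np.array([5, 6], dtype=np.int8))
print(t[0].numpy())                     # 5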