diff --git a/.gitignore b/.gitignore index 136737c..f4d6253 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ _cache/* dist/* build/* .eggs/* +.hypothesis/* *egg-info/* _doc/auto_examples/* _doc/examples/_cache/* diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt new file mode 100644 index 0000000..3cdbb31 --- /dev/null +++ b/_unittests/onnx-numpy-skips.txt @@ -0,0 +1,13 @@ +# API failures +array_api_tests/test_creation_functions.py::test_arange +array_api_tests/test_creation_functions.py::test_asarray_scalars +array_api_tests/test_creation_functions.py::test_asarray_arrays +array_api_tests/test_creation_functions.py::test_empty +array_api_tests/test_creation_functions.py::test_empty_like +array_api_tests/test_creation_functions.py::test_eye +array_api_tests/test_creation_functions.py::test_full +array_api_tests/test_creation_functions.py::test_full_like +array_api_tests/test_creation_functions.py::test_linspace +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_creation_functions.py::test_ones_like +array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/onnx-ort-skips.txt b/_unittests/onnx-ort-skips.txt new file mode 100644 index 0000000..557d14b --- /dev/null +++ b/_unittests/onnx-ort-skips.txt @@ -0,0 +1,15 @@ +# Not implementated by onnxruntime +array_api_tests/test_creation_functions.py::test_arange +array_api_tests/test_creation_functions.py::test_asarray_scalars +array_api_tests/test_creation_functions.py::test_asarray_arrays +array_api_tests/test_creation_functions.py::test_empty +array_api_tests/test_creation_functions.py::test_empty_like +array_api_tests/test_creation_functions.py::test_eye +array_api_tests/test_creation_functions.py::test_full +array_api_tests/test_creation_functions.py::test_full_like +array_api_tests/test_creation_functions.py::test_linspace +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_creation_functions.py::test_ones +array_api_tests/test_creation_functions.py::test_ones_like +array_api_tests/test_creation_functions.py::test_zeros +array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/ut_array_api/test_array_apis.py b/_unittests/ut_array_api/test_array_apis.py new file mode 100644 index 0000000..c72700c --- /dev/null +++ b/_unittests/ut_array_api/test_array_apis.py @@ -0,0 +1,112 @@ +import unittest +from inspect import isfunction, ismethod +import numpy as np +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.array_api import onnx_numpy as xpn +from onnx_array_api.array_api import onnx_ort as xpo + +# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor +# from onnx_array_api.ort.ort_tensors import EagerOrtTensor + + +class TestArraysApis(ExtTestCase): + def test_zeros_numpy_1(self): + c = xpn.zeros(1) + d = c.numpy() + self.assertEqualArray(np.array([0], dtype=np.float32), d) + + def test_zeros_ort_1(self): + c = xpo.zeros(1) + d = c.numpy() + self.assertEqualArray(np.array([0], dtype=np.float32), d) + + def test_ffinfo(self): + dt = np.float32 + fi1 = np.finfo(dt) + fi2 = xpn.finfo(dt) + fi3 = xpo.finfo(dt) + dt1 = fi1.dtype + dt2 = fi2.dtype + dt3 = fi3.dtype + self.assertEqual(dt2, dt3) + self.assertNotEqual(dt1.__class__, dt2.__class__) + mi1 = fi1.min + mi2 = fi2.min + self.assertEqual(mi1, mi2) + mi1 = fi1.smallest_normal + mi2 = fi2.smallest_normal + self.assertEqual(mi1, mi2) + for n in dir(fi1): + if n.startswith("__"): + continue + if n in {"machar"}: + continue + v1 = getattr(fi1, n) + with self.subTest(att=n): + v2 = getattr(fi2, n) + v3 = getattr(fi3, n) + if isfunction(v1) or ismethod(v1): + try: + v1 = v1() + except TypeError: + continue + v2 = v2() + v3 = v3() + if v1 != v2: + raise AssertionError( + f"12: info disagree on name {n!r}: {v1} != {v2}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + if v2 != v3: + raise AssertionError( + f"23: info disagree on name {n!r}: {v2} != {v3}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + + def test_iiinfo(self): + dt = np.int64 + fi1 = np.iinfo(dt) + fi2 = xpn.iinfo(dt) + fi3 = xpo.iinfo(dt) + dt1 = fi1.dtype + dt2 = fi2.dtype + dt3 = fi3.dtype + self.assertEqual(dt2, dt3) + self.assertNotEqual(dt1.__class__, dt2.__class__) + mi1 = fi1.min + mi2 = fi2.min + self.assertEqual(mi1, mi2) + for n in dir(fi1): + if n.startswith("__"): + continue + if n in {"machar"}: + continue + v1 = getattr(fi1, n) + with self.subTest(att=n): + v2 = getattr(fi2, n) + v3 = getattr(fi3, n) + if isfunction(v1) or ismethod(v1): + try: + v1 = v1() + except TypeError: + continue + v2 = v2() + v3 = v3() + if v1 != v2: + raise AssertionError( + f"12: info disagree on name {n!r}: {v1} != {v2}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + if v2 != v3: + raise AssertionError( + f"23: info disagree on name {n!r}: {v2} != {v3}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 30e2ca2..bdf870c 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -2,12 +2,12 @@ import numpy as np from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.array_api import onnx_numpy as xp -from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor +from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor class TestOnnxNumpy(ExtTestCase): def test_abs(self): - c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64)) + c = EagerTensor(np.array([4, 5], dtype=np.int64)) mat = xp.zeros(c, dtype=xp.int64) matnp = mat.numpy() self.assertEqual(matnp.shape, (4, 5)) diff --git a/_unittests/ut_array_api/test_onnx_ort.py b/_unittests/ut_array_api/test_onnx_ort.py new file mode 100644 index 0000000..a10b0d0 --- /dev/null +++ b/_unittests/ut_array_api/test_onnx_ort.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.array_api import onnx_ort as xp +from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor + + +class TestOnnxOrt(ExtTestCase): + def test_abs(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.zeros(c, dtype=xp.int64) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + a = xp.absolute(mat) + self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_ort/test_ort_tensor.py b/_unittests/ut_ort/test_ort_tensor.py index 57340d5..b673557 100644 --- a/_unittests/ut_ort/test_ort_tensor.py +++ b/_unittests/ut_ort/test_ort_tensor.py @@ -1,18 +1,17 @@ import unittest from contextlib import redirect_stdout from io import StringIO - import numpy as np from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator from onnxruntime import InferenceSession - from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.npx import eager_onnx, jit_onnx from onnx_array_api.npx.npx_functions import absolute as absolute_inline from onnx_array_api.npx.npx_functions import cdist as cdist_inline from onnx_array_api.npx.npx_functions_test import absolute -from onnx_array_api.npx.npx_types import Float32, Float64 +from onnx_array_api.npx.npx_functions import copy as copy_inline +from onnx_array_api.npx.npx_types import Float32, Float64, DType from onnx_array_api.npx.npx_var import Input from onnx_array_api.ort.ort_tensors import EagerOrtTensor, JitOrtTensor, OrtTensor @@ -193,6 +192,49 @@ def impl(xa, xb): if len(pieces) > 2: raise AssertionError(f"Function is not using argument:\n{onx}") + def test_astype(self): + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) + onx = f.to_onnx(constraints={"A": Float64[None]}) + x = np.array([[-5, 6]], dtype=np.float64) + z = np.abs(x.astype(np.float32)) + ref = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + got = ref.run(None, {"A": x}) + self.assertEqualArray(z, got[0]) + + def test_astype0(self): + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) + onx = f.to_onnx(constraints={"A": Float64[None]}) + x = np.array(-5, dtype=np.float64) + z = np.abs(x.astype(np.float32)) + ref = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + got = ref.run(None, {"A": x}) + self.assertEqualArray(z, got[0]) + + def test_eager_ort_cast(self): + def impl(A): + return A.astype(DType("FLOAT")) + + e = eager_onnx(impl) + self.assertEqual(len(e.versions), 0) + + # Float64 + x = np.array([0, 1, -2], dtype=np.float64) + z = x.astype(np.float32) + res = e(x) + self.assertEqualArray(z, res) + self.assertEqual(res.dtype, np.float32) + + # again + x = np.array(1, dtype=np.float64) + z = x.astype(np.float32) + res = e(x) + self.assertEqualArray(z, res) + self.assertEqual(res.dtype, np.float32) + if __name__ == "__main__": # TestNpx().test_eager_numpy() diff --git a/azure-pipelines.yml b/azure-pipelines.yml index defe983..b711ecf 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -110,6 +110,8 @@ jobs: displayName: 'Install tools' - script: pip install -r requirements.txt displayName: 'Install Requirements' + - script: pip install onnxruntime + displayName: 'Install onnxruntime' - script: python setup.py install displayName: 'Install onnx_array_api' - script: | @@ -129,8 +131,13 @@ jobs: - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros - displayName: "test_creation_functions.py::test_zeros" + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v + displayName: "numpy test_creation_functions.py" + - script: | + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort + cd array-api-tests + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v + displayName: "ort test_creation_functions.py" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy # cd array-api-tests @@ -246,16 +253,8 @@ jobs: displayName: 'export' - script: gcc --version displayName: 'gcc version' - - script: brew install llvm - displayName: 'install llvm' - - script: brew install libomp - displayName: 'Install omp' - - script: brew install p7zip - displayName: 'Install p7zip' - script: python -m pip install --upgrade pip setuptools wheel displayName: 'Install tools' - - script: brew install pybind11 - displayName: 'Install pybind11' - script: pip install -r requirements.txt displayName: 'Install Requirements' - script: pip install -r requirements-dev.txt diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py new file mode 100644 index 0000000..6191c92 --- /dev/null +++ b/onnx_array_api/_helpers.py @@ -0,0 +1,45 @@ +import numpy as np +from typing import Any +from onnx import helper, TensorProto + + +def np_dtype_to_tensor_dtype(dtype: Any): + """ + Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`. + """ + try: + dt = helper.np_dtype_to_tensor_dtype(dtype) + except KeyError: + if dtype == np.float32: + dt = TensorProto.FLOAT + elif dtype == np.float64: + dt = TensorProto.DOUBLE + elif dtype == np.int64: + dt = TensorProto.INT64 + elif dtype == np.int32: + dt = TensorProto.INT32 + elif dtype == np.int16: + dt = TensorProto.INT16 + elif dtype == np.int8: + dt = TensorProto.INT8 + elif dtype == np.uint64: + dt = TensorProto.UINT64 + elif dtype == np.uint32: + dt = TensorProto.UINT32 + elif dtype == np.uint16: + dt = TensorProto.UINT16 + elif dtype == np.uint8: + dt = TensorProto.UINT8 + elif dtype == np.float16: + dt = TensorProto.FLOAT16 + elif dtype in (bool, np.bool_): + dt = TensorProto.BOOL + elif dtype in (str, np.str_): + dt = TensorProto.STRING + elif dtype is int: + dt = TensorProto.INT64 + elif dtype is float: + dt = TensorProto.FLOAT64 + else: + raise KeyError(f"Unable to guess type for dtype={dtype}.") + return dt diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index e13b184..cc64b8e 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -1,8 +1,42 @@ +import numpy as np from onnx import TensorProto +from .._helpers import np_dtype_to_tensor_dtype from ..npx.npx_types import DType +def _finfo(dtype): + """ + Similar to :class:`numpy.finfo`. + """ + dt = dtype.np_dtype if isinstance(dtype, DType) else dtype + res = np.finfo(dt) + d = res.__dict__.copy() + d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) + nres = type("finfo", (res.__class__,), d) + setattr(nres, "smallest_normal", res.smallest_normal) + setattr(nres, "tiny", res.tiny) + return nres + + +def _iinfo(dtype): + """ + Similar to :class:`numpy.finfo`. + """ + dt = dtype.np_dtype if isinstance(dtype, DType) else dtype + res = np.iinfo(dt) + d = res.__dict__.copy() + d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) + nres = type("finfo", (res.__class__,), d) + setattr(nres, "min", res.min) + setattr(nres, "max", res.max) + return nres + + def _finalize_array_api(module): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ module.float16 = DType(TensorProto.FLOAT16) module.float32 = DType(TensorProto.FLOAT) module.float64 = DType(TensorProto.DOUBLE) @@ -17,3 +51,5 @@ def _finalize_array_api(module): module.bfloat16 = DType(TensorProto.BFLOAT16) setattr(module, "bool", DType(TensorProto.BOOL)) setattr(module, "str", DType(TensorProto.STRING)) + setattr(module, "finfo", _finfo) + setattr(module, "iinfo", _iinfo) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 8d136c4..25ace54 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -41,6 +41,8 @@ def template_asarray( v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): v = TEagerTensor(np.array(a, dtype=np.str_)) + elif isinstance(a, list): + v = TEagerTensor(np.array(a)) else: raise RuntimeError(f"Unexpected type {type(a)} for the first input.") if dtype is not None: diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 79b339d..c20fb15 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -11,9 +11,12 @@ astype, equal, isdtype, + isfinite, + isnan, reshape, take, ) +from ..npx.npx_functions import ones as generic_ones from ..npx.npx_functions import zeros as generic_zeros from ..npx.npx_numpy_tensors import EagerNumpyTensor from ..npx.npx_types import DType, ElemType, TensorType, OptParType @@ -28,6 +31,9 @@ "astype", "equal", "isdtype", + "isfinite", + "isnan", + "ones", "reshape", "take", "zeros", @@ -49,6 +55,24 @@ def asarray( ) +def ones( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if isinstance(shape, tuple): + return generic_ones( + EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order + ) + if isinstance(shape, int): + return generic_ones( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + dtype=dtype, + order=order, + ) + return generic_ones(shape, dtype=dtype, order=order) + + def zeros( shape: TensorType[ElemType.int64, "I", (None,)], dtype: OptParType[DType] = DType(TensorProto.FLOAT), @@ -58,10 +82,20 @@ def zeros( return generic_zeros( EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order ) + if isinstance(shape, int): + return generic_zeros( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + dtype=dtype, + order=order, + ) return generic_zeros(shape, dtype=dtype, order=order) def _finalize(): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ from . import onnx_numpy _finalize_array_api(onnx_numpy) diff --git a/onnx_array_api/array_api/onnx_ort.py b/onnx_array_api/array_api/onnx_ort.py index 505efdf..56f6444 100644 --- a/onnx_array_api/array_api/onnx_ort.py +++ b/onnx_array_api/array_api/onnx_ort.py @@ -2,8 +2,9 @@ Array API valid for an :class:`EagerOrtTensor`. """ from typing import Optional, Any +import numpy as np +from onnx import TensorProto from ..ort.ort_tensors import EagerOrtTensor -from ..npx.npx_types import DType from ..npx.npx_functions import ( all, abs, @@ -11,9 +12,13 @@ astype, equal, isdtype, + isnan, + isfinite, reshape, take, ) +from ..npx.npx_types import DType, ElemType, TensorType, OptParType +from ..npx.npx_functions import zeros as generic_zeros from ._onnx_common import template_asarray from . import _finalize_array_api @@ -25,6 +30,8 @@ "astype", "equal", "isdtype", + "isfinite", + "isnan", "reshape", "take", ] @@ -45,7 +52,27 @@ def asarray( ) +def zeros( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if isinstance(shape, tuple): + return generic_zeros( + EagerOrtTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order + ) + if isinstance(shape, int): + return generic_zeros( + EagerOrtTensor(np.array([shape], dtype=np.int64)), dtype=dtype, order=order + ) + return generic_zeros(shape, dtype=dtype, order=order) + + def _finalize(): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ from . import onnx_ort _finalize_array_api(onnx_ort) diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index b55cf4d..29a4481 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -1,11 +1,10 @@ from typing import Optional, Tuple, Union - import array_api_compat.numpy as np_array_api import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto -from onnx.helper import make_tensor, np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype +from onnx.helper import make_tensor, tensor_dtype_to_np_dtype from onnx.numpy_helper import from_array - +from .._helpers import np_dtype_to_tensor_dtype from .npx_constants import FUNCTION_DOMAIN from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var from .npx_types import ( @@ -203,15 +202,7 @@ def astype( raise TypeError( f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}." ) - try: - to = np_dtype_to_tensor_dtype(dtype) - except KeyError: - if dtype is int: - to = TensorProto.INT64 - elif dtype is float: - to = TensorProto.float64 - else: - raise ValueError(f"Unable to guess tensor type from {dtype}.") + to = np_dtype_to_tensor_dtype(dtype) return var(a, op="Cast", to=to) @@ -351,7 +342,7 @@ def einsum( def equal( x: TensorType[ElemType.allowed, "T"], y: TensorType[ElemType.allowed, "T"] ) -> TensorType[ElemType.bool_, "T1"]: - "See :func:`numpy.isnan`." + "See :func:`numpy.equal`." return var(x, y, op="Equal") @@ -437,6 +428,12 @@ def isdtype( return np_array_api.isdtype(dtype, kind) +@npxapi_inline +def isfinite(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: + "See :func:`numpy.isfinite`." + return var(x, op="IsInf") + + @npxapi_inline def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: "See :func:`numpy.isnan`." @@ -464,6 +461,26 @@ def matmul( return var(a, b, op="MatMul") +@npxapi_inline +def ones( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + return var( + shape, + value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]), + op="ConstantOfShape", + ) + + @npxapi_inline def pad( x: TensorType[ElemType.numerics, "T"], diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index ec91b91..d41b91c 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -273,7 +273,7 @@ def _io( self, index: int, name: str, tensor_type: Optional[type], is_input: bool ) -> ValueInfoProto: """ - Converts an input or outut into :class:`onnx.ValueInfoProto`. + Converts an input or output into :class:`onnx.ValueInfoProto`. :param index: index of the input or output to add :param name: input or output name diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 85b52d4..35ff9af 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -267,6 +267,18 @@ def to_jit(self, *values, **kwargs): target_opsets=self.target_opsets, ir_version=self.ir_version, ) + if len(values) > 0 and len(values[0].shape) == 0: + inps = onx.graph.input[0] + shape = [] + for d in inps.type.tensor_type.shape.dim: + v = d.dim_value if d.dim_value > 0 else d.dim_param + shape.append(v) + if len(shape) != 0: + raise RuntimeError( + f"Shape mismatch, values[0]={values[0]} " + f"and inputs={onx.graph.input}." + ) + exe = self.tensor_class.create_function(names, onx) self.info("-", "to_jit") return onx, exe diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index e1a0c10..15f9588 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,10 +1,8 @@ from typing import Any, Callable, List, Optional, Tuple - import numpy as np from onnx import ModelProto -from onnx.helper import np_dtype_to_tensor_dtype from onnx.reference import ReferenceEvaluator - +from .._helpers import np_dtype_to_tensor_dtype from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor from .npx_types import DType, TensorType @@ -107,13 +105,12 @@ def dims(self): """ Returns the dimensions of the tensor. First dimension is the batch dimension if the tensor - has more than one dimension. + has more than one dimension. It is always left undefined. """ - if len(self._tensor.shape) == 0: - return (0,) - if len(self._tensor.shape) == 1: + if len(self._tensor.shape) <= 1: + # a scalar (len==0) or a 1D tensor return self._tensor.shape - return (None,) + self._tensor.shape[1:] + return (None, *tuple(self.shape[1:])) @property def ndim(self): diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index e1e4b21..b0e92c2 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -1,8 +1,6 @@ from typing import Any, Union - import numpy as np -from onnx.helper import np_dtype_to_tensor_dtype - +from .._helpers import np_dtype_to_tensor_dtype from .npx_types import DType, ElemType, ParType, TensorType from .npx_array_api import BaseArrayApi, ArrayApiError @@ -77,9 +75,9 @@ def _getitem_impl_var(obj, index, method_name=None): def _astype_impl( x: TensorType[ElemType.allowed, "T1"], dtype: ParType[DType], method_name=None ) -> TensorType[ElemType.allowed, "T2"]: - # avoids circular imports. if dtype is None: raise ValueError("dtype cannot be None.") + # avoids circular imports. from .npx_var import Var if not isinstance(x, Var): @@ -178,10 +176,6 @@ def _generic_method_reduce(self, method_name, *args: Any, **kwargs: Any) -> Any: @staticmethod def _np_dtype_to_tensor_dtype(dtype): - if dtype == int: - dtype = np.dtype("int64") - elif dtype == float: - dtype = np.dtype("float64") return np_dtype_to_tensor_dtype(dtype) def _generic_method_astype( diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index aa335bd..6063e64 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -2,7 +2,8 @@ import numpy as np from onnx import AttributeProto, TensorProto -from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype +from onnx.helper import tensor_dtype_to_np_dtype +from .._helpers import np_dtype_to_tensor_dtype class WrapperType: @@ -18,13 +19,20 @@ class DType(WrapperType): Type of the element type returned by tensors following the :epkg:`Array API`. - :param code: element type based on onnx definition + :param code: element type based on onnx definition, + if str, it looks into class :class:`onnxTensorProto` + to retrieve the code """ __slots__ = ["code_"] - def __init__(self, code: int): - self.code_ = code + def __init__(self, code: Union[int, str]): + if isinstance(code, str): + self.code_ = getattr(TensorProto, code) + elif isinstance(code, int): + self.code_ = code + else: + raise TypeError(f"Unsupported type {type(code)}:{code!r}") def __repr__(self) -> str: "usual" @@ -55,6 +63,8 @@ def __eq__(self, dt: "DType") -> bool: return self.code_ == TensorProto.STRING if dt is bool: return self.code_ == TensorProto.BOOL + if isinstance(dt, list): + return False if dt in ElemType.numpy_map: dti = ElemType.numpy_map[dt] return self.code_ == dti.code_ diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index ae5b732..42b1b5a 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1,9 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union - import numpy as np -from onnx import FunctionProto, ModelProto, NodeProto, TensorProto -from onnx.helper import np_dtype_to_tensor_dtype - +from onnx import FunctionProto, ModelProto, NodeProto +from .._helpers import np_dtype_to_tensor_dtype from .npx_array_api import BaseArrayApi, ArrayApiError from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN from .npx_types import DType, ElemType, OptParType, ParType, TensorType, TupleType @@ -199,6 +197,15 @@ class Var(BaseArrayApi): :param onnx_input_type_: names given to the variables """ + def __array_namespace__(self, api_version: Optional[str] = None): + """ + Raises an exception if called. + """ + raise RuntimeError( + f"This function should never be called for class {type(self)}. " + f"It should be called for an eager tensor." + ) + @staticmethod def get_cst_var(): from .npx_core_api import cst, var @@ -822,38 +829,7 @@ def astype(self, dtype) -> "Var": if isinstance(dtype, Var): return var(self.self_var, dtype, op="CastLike") if not isinstance(dtype, int): - try: - dtype = np_dtype_to_tensor_dtype(dtype) - except KeyError: - if dtype == np.float32: - dtype = TensorProto.FLOAT - elif dtype == np.float64: - dtype = TensorProto.DOUBLE - elif dtype == np.int64: - dtype = TensorProto.INT64 - elif dtype == np.int32: - dtype = TensorProto.INT32 - elif dtype == np.int16: - dtype = TensorProto.INT16 - elif dtype == np.int8: - dtype = TensorProto.INT8 - elif dtype == np.uint64: - dtype = TensorProto.UINT64 - elif dtype == np.uint32: - dtype = TensorProto.UINT32 - elif dtype == np.uint16: - dtype = TensorProto.UINT16 - elif dtype == np.uint8: - dtype = TensorProto.UINT8 - elif dtype == np.float16: - dtype = TensorProto.FLOAT16 - elif dtype in (bool, np.bool_): - dtype = TensorProto.BOOL - elif dtype in (str, np.str_): - dtype = TensorProto.STRING - else: - raise RuntimeError(f"Unable to guess type for dtype={dtype}.") - + dtype = np_dtype_to_tensor_dtype(dtype) return var(self.self_var, op="Cast", to=dtype) @property @@ -976,7 +952,7 @@ def __getitem__(self, index: Any) -> "Var": cst, var = Var.get_cst_var() if self.n_var_outputs != 1: - # Multioutut + # Multioutput if not isinstance(index, int): raise TypeError( f"Only indices are allowed when selecting an output, " diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index ead834d..db9d4d5 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -148,7 +148,7 @@ def ndim(self): @property def shape(self) -> Tuple[int, ...]: "Returns the shape of the tensor." - return self._tensor.shape() + return tuple(self._tensor.shape()) @property def dtype(self) -> DType: @@ -175,12 +175,11 @@ def dims(self): """ Returns the dimensions of the tensor. First dimension is the batch dimension if the tensor - has more than one dimension. + has more than one dimension. It is always left undefined. """ - if len(self.shape) == 0: - return (0,) - if len(self.shape) == 1: - return tuple(self.shape) + if len(self._tensor.shape()) <= 1: + # a scalar (len==0) or a 1D tensor + return tuple(self._tensor.shape()) return (None, *tuple(self.shape[1:])) @property