diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index 3cdbb31..eef3e70 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -1,6 +1,6 @@ # API failures +# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt array_api_tests/test_creation_functions.py::test_arange -array_api_tests/test_creation_functions.py::test_asarray_scalars array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index b32ee41..cb32fe4 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,2 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros \ No newline at end of file +# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1 +# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help +pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index bdf870c..bd79ecf 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -6,7 +6,11 @@ class TestOnnxNumpy(ExtTestCase): - def test_abs(self): + def test_empty(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + self.assertRaise(lambda: xp.empty(c, dtype=xp.int64), RuntimeError) + + def test_zeros(self): c = EagerTensor(np.array([4, 5], dtype=np.int64)) mat = xp.zeros(c, dtype=xp.int64) matnp = mat.numpy() diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py index c9ee35f..93f2b5e 100644 --- a/_unittests/ut_npx/test_npx.py +++ b/_unittests/ut_npx/test_npx.py @@ -2501,7 +2501,7 @@ def test_numpy_all(self): got = ref.run(None, {"A": data}) self.assertEqualArray(y, got[0]) - def test_numpy_all_empty(self): + def test_numpy_all_zeros(self): data = np.zeros((0,), dtype=np.bool_) y = np.all(data) @@ -2513,7 +2513,7 @@ def test_numpy_all_empty(self): self.assertEqualArray(y, got[0]) @unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0") - def test_numpy_all_empty_axis_0(self): + def test_numpy_all_zeros_axis_0(self): data = np.zeros((0, 1), dtype=np.bool_) y = np.all(data, axis=0) @@ -2535,7 +2535,13 @@ def test_numpy_all_empty_axis_1(self): got = ref.run(None, {"A": data}) self.assertEqualArray(y, got[0]) + @unittest.skipIf(True, reason="Fails to follow Array API") + def test_get_item(self): + a = EagerNumpyTensor(np.array([True], dtype=np.bool_)) + i = a[0] + self.assertEqualArray(i.numpy(), a.numpy()[0]) + if __name__ == "__main__": - # TestNpx().test_numpy_all_empty_axis_0() + # TestNpx().test_get_item() unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b711ecf..ca24462 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -131,12 +131,12 @@ jobs: - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain displayName: "numpy test_creation_functions.py" - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain displayName: "ort test_creation_functions.py" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 25ace54..6553137 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -34,7 +34,15 @@ def template_asarray( return a.astype(dtype=dtype) if isinstance(a, int): - v = TEagerTensor(np.array(a, dtype=np.int64)) + if a is False: + v = TEagerTensor(np.array(False, dtype=np.bool_)) + elif a is True: + v = TEagerTensor(np.array(True, dtype=np.bool_)) + else: + try: + v = TEagerTensor(np.asarray(a, dtype=np.int64)) + except OverflowError: + v = TEagerTensor(np.asarray(a, dtype=np.uint64)) elif isinstance(a, float): v = TEagerTensor(np.array(a, dtype=np.float32)) elif isinstance(a, bool): diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index c20fb15..2cd4bfd 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -29,6 +29,7 @@ "all", "asarray", "astype", + "empty", "equal", "isdtype", "isfinite", @@ -73,6 +74,17 @@ def ones( return generic_ones(shape, dtype=dtype, order=order) +def empty( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + raise RuntimeError( + "ONNX assumes there is no inplace implementation. " + "empty function is only used in that case." + ) + + def zeros( shape: TensorType[ElemType.int64, "I", (None,)], dtype: OptParType[DType] = DType(TensorProto.FLOAT), diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 15f9588..f89ed9f 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Optional, Tuple import numpy as np -from onnx import ModelProto +from onnx import ModelProto, TensorProto from onnx.reference import ReferenceEvaluator from .._helpers import np_dtype_to_tensor_dtype from .npx_numpy_tensors_ops import ConstantOfShape @@ -183,6 +183,60 @@ def __array_namespace__(self, api_version: Optional[str] = None): f"Unable to return an implementation for api_version={api_version!r}." ) + def __bool__(self): + "Implicit conversion to bool." + if self.dtype != DType(TensorProto.BOOL): + raise TypeError( + f"Conversion to bool only works for bool scalar, not for {self!r}." + ) + if self.shape == (0,): + return False + if len(self.shape) != 0: + raise ValueError( + f"Conversion to bool only works for scalar, not for {self!r}." + ) + return bool(self._tensor) + + def __int__(self): + "Implicit conversion to bool." + if len(self.shape) != 0: + raise ValueError( + f"Conversion to bool only works for scalar, not for {self!r}." + ) + if self.dtype not in { + DType(TensorProto.INT64), + DType(TensorProto.INT32), + DType(TensorProto.INT16), + DType(TensorProto.INT8), + DType(TensorProto.UINT64), + DType(TensorProto.UINT32), + DType(TensorProto.UINT16), + DType(TensorProto.UINT8), + }: + raise TypeError( + f"Conversion to int only works for int scalar, " + f"not for dtype={self.dtype}." + ) + return int(self._tensor) + + def __float__(self): + "Implicit conversion to bool." + if len(self.shape) != 0: + raise ValueError( + f"Conversion to bool only works for scalar, not for {self!r}." + ) + if self.dtype not in { + DType(TensorProto.FLOAT), + DType(TensorProto.DOUBLE), + DType(TensorProto.FLOAT16), + DType(TensorProto.BFLOAT16), + }: + raise TypeError( + f"Conversion to int only works for float scalar, " + f"not for dtype={self.dtype}." + ) + return float(self._tensor) + class JitNumpyTensor(NumpyTensor, JitTensor): """ diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 42b1b5a..2759f4c 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -962,6 +962,7 @@ def __getitem__(self, index: Any) -> "Var": if isinstance(index, Var): # scenario 2 + # TODO: fix this when index is an integer new_shape = cst(np.array([-1], dtype=np.int64)) new_self = self.reshape(new_shape) new_index = index.reshape(new_shape) @@ -973,6 +974,9 @@ def __getitem__(self, index: Any) -> "Var": if not isinstance(index, tuple): index = (index,) + elif len(index) == 0: + # The array contains a scalar and it needs to be returned. + return var(self, op="Identity") # only one integer? ni = None