diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index d4c0a69dfd..d6e15f9339 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -52,6 +52,15 @@ cdef class InternalUSMArrayError(Exception): pass +cdef object _as_zero_dim_ndarray(object usm_ary): + "Convert size-1 array to NumPy 0d array" + mem_view = dpmem.as_usm_memory(usm_ary) + host_buf = mem_view.copy_to_host() + view = host_buf.view(usm_ary.dtype) + view.shape = tuple() + return view + + cdef class usm_ndarray: """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \ offset=0, order="C", buffer_ctor_kwargs=dict(), \ @@ -840,9 +849,7 @@ cdef class usm_ndarray: def __bool__(self): if self.size == 1: - mem_view = dpmem.as_usm_memory(self) - host_buf = mem_view.copy_to_host() - view = host_buf.view(self.dtype) + view = _as_zero_dim_ndarray(self) return view.__bool__() if self.size == 0: @@ -857,9 +864,7 @@ cdef class usm_ndarray: def __float__(self): if self.size == 1: - mem_view = dpmem.as_usm_memory(self) - host_buf = mem_view.copy_to_host() - view = host_buf.view(self.dtype) + view = _as_zero_dim_ndarray(self) return view.__float__() raise ValueError( @@ -868,9 +873,7 @@ cdef class usm_ndarray: def __complex__(self): if self.size == 1: - mem_view = dpmem.as_usm_memory(self) - host_buf = mem_view.copy_to_host() - view = host_buf.view(self.dtype) + view = _as_zero_dim_ndarray(self) return view.__complex__() raise ValueError( @@ -879,9 +882,7 @@ cdef class usm_ndarray: def __int__(self): if self.size == 1: - mem_view = dpmem.as_usm_memory(self) - host_buf = mem_view.copy_to_host() - view = host_buf.view(self.dtype) + view = _as_zero_dim_ndarray(self) return view.__int__() raise ValueError( diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 8e71f3931d..82b4303460 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -239,8 +239,9 @@ def test_copy_scalar_with_func(func, shape, dtype): X = dpt.usm_ndarray(shape, dtype=dtype) except dpctl.SyclDeviceCreationError: pytest.skip("No SYCL devices available") - Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape) - X.usm_data.copy_from_host(Y.reshape(-1).view("|u1")) + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() assert func(X) == func(Y) @@ -254,8 +255,9 @@ def test_copy_scalar_with_method(method, shape, dtype): X = dpt.usm_ndarray(shape, dtype=dtype) except dpctl.SyclDeviceCreationError: pytest.skip("No SYCL devices available") - Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape) - X.usm_data.copy_from_host(Y.reshape(-1).view("|u1")) + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() assert getattr(X, method)() == getattr(Y, method)()