From 051d3e43066ebc407694507c5da1c2cf90a8290c Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Sun, 27 Aug 2023 20:55:31 +0200 Subject: [PATCH 1/3] Leveraged dpctl.tensor.copy() implementation --- .github/workflows/conda-package.yml | 1 + dpnp/dpnp_algo/dpnp_algo_mathematical.pxi | 4 +- dpnp/dpnp_array.py | 52 ++++++- dpnp/dpnp_container.py | 24 +-- dpnp/dpnp_iface_arraycreation.py | 109 +++++++++---- tests/skipped_tests.tbl | 2 - tests/skipped_tests_gpu.tbl | 2 - tests/test_arraycreation.py | 3 +- tests/test_copy.py | 74 +++++++++ .../cupy/core_tests/test_elementwise.py | 143 ++++++++++++++++++ .../cupy/creation_tests/test_from_data.py | 1 - 11 files changed, 363 insertions(+), 52 deletions(-) create mode 100644 tests/test_copy.py create mode 100644 tests/third_party/cupy/core_tests/test_elementwise.py diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index b522fd272895..fa99a9e02fd6 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -15,6 +15,7 @@ env: test_arraycreation.py test_dot.py test_dparray.py + test_copy.py test_fft.py test_linalg.py test_logic.py diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi index d2ec12b4c3d2..7ca664fceaba 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi @@ -352,7 +352,7 @@ cpdef tuple dpnp_modf(utils.dpnp_descriptor x1): cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1): - cur_x1 = dpnp_copy(x1).get_pyobj() + cur_x1 = x1.get_pyobj().copy() cur_x1_flatiter = cur_x1.flat @@ -365,7 +365,7 @@ cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1): cpdef utils.dpnp_descriptor dpnp_nancumsum(utils.dpnp_descriptor x1): - cur_x1 = dpnp_copy(x1).get_pyobj() + cur_x1 = x1.get_pyobj().copy() cur_x1_flatiter = cur_x1.flat diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index d19821813b29..a73a5a6cfaaa 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -168,7 +168,15 @@ def __complex__(self): return self._array_obj.__complex__() # '__contains__', - # '__copy__', + + def __copy__(self): + """ + Used if :func:`copy.copy` is called on an array. Returns a copy of the array. + + Equivalent to ``a.copy(order="K")``. + """ + return self.copy(order="K") + # '__deepcopy__', # '__delattr__', # '__delitem__', @@ -651,7 +659,47 @@ def conjugate(self): else: return dpnp.conjugate(self) - # 'copy', + def copy(self, order="C"): + """ + Return a copy of the array. + + Returns + ------- + out : dpnp.ndarray + A copy of the array. + + See also + -------- + :obj:`dpnp.copy` : Similar function with different default behavior + :obj:`dpnp.copyto` : Copies values from one array to another. + + Notes + ----- + This function is the preferred method for creating an array copy. The + function :func:`dpnp.copy` is similar, but it defaults to using order 'K'. + + Examples + -------- + >>> import dpnp as np + >>> x = np.array([[1, 2, 3], [4, 5, 6]], order='F') + >>> y = x.copy() + >>> x.fill(0) + + >>> x + array([[0, 0, 0], + [0, 0, 0]]) + + >>> y + array([[1, 2, 3], + [4, 5, 6]]) + + >>> y.flags['C_CONTIGUOUS'] + True + + """ + + return dpnp.copy(self, order=order) + # 'ctypes', # 'cumprod', diff --git a/dpnp/dpnp_container.py b/dpnp/dpnp_container.py index 08522ceaafa6..f7e8cc355d82 100644 --- a/dpnp/dpnp_container.py +++ b/dpnp/dpnp_container.py @@ -43,6 +43,7 @@ __all__ = [ "arange", "asarray", + "copy", "empty", "eye", "full", @@ -135,6 +136,17 @@ def asarray( return dpnp_array(array_obj.shape, buffer=array_obj, order=order) +def copy(x1, /, *, order="K"): + """Creates `dpnp_array` as a copy of given instance of input array.""" + if order is None: + order = "K" + else: + order = order.upper() + + array_obj = dpt.copy(dpnp.get_usm_ndarray(x1), order=order) + return dpnp_array(array_obj.shape, buffer=array_obj, order="K") + + def empty( shape, *, @@ -264,9 +276,7 @@ def meshgrid(*xi, indexing="xy"): """Creates list of `dpnp_array` coordinate matrices from vectors.""" if len(xi) == 0: return [] - arrays = tuple( - x.get_array() if isinstance(x, dpnp_array) else x for x in xi - ) + arrays = tuple(dpnp.get_usm_ndarray(x) for x in xi) arrays_obj = dpt.meshgrid(*arrays, indexing=indexing) return [ dpnp_array._create_from_usm_ndarray(array_obj) @@ -304,17 +314,13 @@ def ones( def tril(x1, /, *, k=0): """Creates `dpnp_array` as lower triangular part of an input array.""" - array_obj = dpt.tril( - x1.get_array() if isinstance(x1, dpnp_array) else x1, k - ) + array_obj = dpt.tril(dpnp.get_usm_ndarray(x1), k) return dpnp_array(array_obj.shape, buffer=array_obj, order="K") def triu(x1, /, *, k=0): """Creates `dpnp_array` as upper triangular part of an input array.""" - array_obj = dpt.triu( - x1.get_array() if isinstance(x1, dpnp_array) else x1, k - ) + array_obj = dpt.triu(dpnp.get_usm_ndarray(x1), k) return dpnp_array(array_obj.shape, buffer=array_obj, order="K") diff --git a/dpnp/dpnp_iface_arraycreation.py b/dpnp/dpnp_iface_arraycreation.py index 549920066ac0..ab5c6d1ce32a 100644 --- a/dpnp/dpnp_iface_arraycreation.py +++ b/dpnp/dpnp_iface_arraycreation.py @@ -110,7 +110,7 @@ def arange( Returns ------- - arange : :obj:`dpnp.ndarray` + out : dpnp.ndarray The 1-D array containing evenly spaced values. Limitations @@ -151,7 +151,7 @@ def arange( def array( - x, + a, dtype=None, *, copy=True, @@ -168,6 +168,11 @@ def array( For full documentation refer to :obj:`numpy.array`. + Returns + ------- + out : dpnp.ndarray + An array object satisfying the specified requirements. + Limitations ----------- Parameter `subok` is supported only with default value ``False``. @@ -228,7 +233,7 @@ def array( copy = None return dpnp_container.asarray( - x, + a, dtype=dtype, copy=copy, order=order, @@ -239,7 +244,7 @@ def array( def asanyarray( - x, + a, dtype=None, order=None, *, @@ -249,10 +254,15 @@ def asanyarray( sycl_queue=None, ): """ - Convert the input to an ndarray, but pass ndarray subclasses through. + Convert the input to an :class:`dpnp.ndarray`. For full documentation refer to :obj:`numpy.asanyarray`. + Returns + ------- + out : dpnp.ndarray + Array interpretation of `a`. + Limitations ----------- Parameter `like` is supported only with default value ``None``. @@ -286,7 +296,7 @@ def asanyarray( ) return asarray( - x, + a, dtype=dtype, order=order, device=device, @@ -296,7 +306,7 @@ def asanyarray( def asarray( - x, + a, dtype=None, order=None, like=None, @@ -309,6 +319,12 @@ def asarray( For full documentation refer to :obj:`numpy.asarray`. + Returns + ------- + out : dpnp.ndarray + Array interpretation of `a`. No copy is performed if the input + is already an ndarray with matching dtype and order. + Limitations ----------- Parameter `like` is supported only with default value ``None``. @@ -342,7 +358,7 @@ def asarray( ) return dpnp_container.asarray( - x, + a, dtype=dtype, order=order, device=device, @@ -352,13 +368,19 @@ def asarray( def ascontiguousarray( - x, dtype=None, *, like=None, device=None, usm_type=None, sycl_queue=None + a, dtype=None, *, like=None, device=None, usm_type=None, sycl_queue=None ): """ Return a contiguous array (ndim >= 1) in memory (C order). For full documentation refer to :obj:`numpy.ascontiguousarray`. + Returns + ------- + out : dpnp.ndarray + Contiguous array of same shape and content as `a`, with type `dtype` + if specified. + Limitations ----------- Parameter `like` is supported only with default value ``None``. @@ -407,11 +429,11 @@ def ascontiguousarray( ) # at least 1-d array has to be returned - if x.ndim == 0: - x = [x] + if a.ndim == 0: + a = [a] return asarray( - x, + a, dtype=dtype, order="C", device=device, @@ -421,13 +443,18 @@ def ascontiguousarray( def asfortranarray( - x, dtype=None, *, like=None, device=None, usm_type=None, sycl_queue=None + a, dtype=None, *, like=None, device=None, usm_type=None, sycl_queue=None ): """ Return an array (ndim >= 1) laid out in Fortran order in memory. For full documentation refer to :obj:`numpy.asfortranarray`. + Returns + ------- + out : dpnp.ndarray + The input `a` in Fortran, or column-major, order. + Limitations ----------- Parameter `like` is supported only with default value ``None``. @@ -479,11 +506,11 @@ def asfortranarray( ) # at least 1-d array has to be returned - if x.ndim == 0: - x = [x] + if a.ndim == 0: + a = [a] return asarray( - x, + a, dtype=dtype, order="F", device=device, @@ -492,7 +519,7 @@ def asfortranarray( ) -def copy(x1, order="K", subok=False): +def copy(a, order="K", subok=False): """ Return an array copy of the given object. @@ -500,35 +527,53 @@ def copy(x1, order="K", subok=False): Limitations ----------- - Parameter ``order`` is supported only with default value ``"C"``. - Parameter ``subok`` is supported only with default value ``False``. + Parameter `subok` is supported only with default value ``False``. + Otherwise, the function raises `ValueError` exception. + + Returns + ------- + out : dpnp.ndarray + Array interpretation of `a`. + + See Also + -------- + :obj:`dpnp.ndarray.copy` : Preferred method for creating an array copy + + Notes + ----- + This is equivalent to: + + >>> dpnp.array(a, copy=True) Examples -------- + Create an array `x`, with a reference `y` and a copy `z`: + >>> import dpnp as np >>> x = np.array([1, 2, 3]) >>> y = x >>> z = np.copy(x) + + Note that, when we modify `x`, `y` will change, but not `z`: + >>> x[0] = 10 >>> x[0] == y[0] - True + array(True) >>> x[0] == z[0] - False + array(False) """ - x1_desc = dpnp.get_dpnp_descriptor( - x1, copy_when_strides=False, copy_when_nondefault_queue=False - ) - if x1_desc: - if order != "K": - pass - elif subok: - pass - else: - return dpnp_copy(x1_desc).get_pyobj() + if subok is not False: + raise ValueError( + "Keyword argument `subok` is supported only with " + f"default value ``False``, but got {subok}" + ) + + if dpnp.is_supported_array_type(a): + return dpnp_container.copy(a, order=order) - return call_origin(numpy.copy, x1, order, subok) + return array(a, order=order, subok=subok, copy=True) def diag(x1, k=0): diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 60e9893ef69e..43dfdbad71dc 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -209,8 +209,6 @@ tests/third_party/cupy/creation_tests/test_basic.py::TestBasic::test_ones_like_s tests/third_party/cupy/creation_tests/test_basic.py::TestBasic::test_zeros_like_subok tests/third_party/cupy/creation_tests/test_basic.py::TestBasic::test_zeros_strides -tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_copy_order - tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_construction_from_list tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_construction_from_tuple tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_extraction_from_nested_list diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index ade729c95cde..8cc41b3987c7 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -371,8 +371,6 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_ tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_zero_num_no_endopoint_with_retstep tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_logspace_zero_num -tests/third_party/cupy/creation_tests/test_from_data.py::TestFromData::test_copy_order - tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_1_{axes=None, norm=None, s=(1, None), shape=(3, 4)}::test_fft2 tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_fft2 tests/third_party/cupy/fft_tests/test_fft.py::TestFft2_param_7_{axes=(), norm=None, s=None, shape=(3, 4)}::test_ifft2 diff --git a/tests/test_arraycreation.py b/tests/test_arraycreation.py index f8f6007c2f82..c69cdf61dd88 100644 --- a/tests/test_arraycreation.py +++ b/tests/test_arraycreation.py @@ -1,4 +1,3 @@ -import operator import tempfile from math import prod @@ -10,7 +9,6 @@ assert_allclose, assert_almost_equal, assert_array_equal, - assert_raises, ) import dpnp @@ -31,6 +29,7 @@ pytest.param("asarray", {"like": dpnp.array([1, 5])}), pytest.param("ascontiguousarray", {"like": dpnp.zeros(4)}), pytest.param("asfortranarray", {"like": dpnp.empty((2, 4))}), + pytest.param("copy", {"subok": True}), ], ) def test_array_copy_exception(func, kwargs): diff --git a/tests/test_copy.py b/tests/test_copy.py new file mode 100644 index 000000000000..e2b161432abd --- /dev/null +++ b/tests/test_copy.py @@ -0,0 +1,74 @@ +import copy + +import numpy +import pytest +from numpy.testing import ( + assert_allclose, + assert_equal, +) + +import dpnp + + +class TestCopyOrder: + a = dpnp.arange(24).reshape(2, 1, 3, 4) + b = a.copy(order="F") + c = dpnp.arange(24).reshape(2, 1, 4, 3).swapaxes(2, 3) + + def check_result(self, x, y, c_contig, f_contig): + assert not (x is y) + assert x.flags.c_contiguous == c_contig + assert x.flags.f_contiguous == f_contig + assert_equal(x, y) + + @pytest.mark.parametrize("arr", [a, b, c]) + def test_order_c(self, arr): + res = arr.copy(order="C") + self.check_result(res, arr, c_contig=True, f_contig=False) + + res = dpnp.copy(arr, order="C") + self.check_result(res, arr, c_contig=True, f_contig=False) + + @pytest.mark.parametrize("arr", [a, b, c]) + def test_order_f(self, arr): + res = arr.copy(order="F") + self.check_result(res, arr, c_contig=False, f_contig=True) + + res = dpnp.copy(arr, order="F") + self.check_result(res, arr, c_contig=False, f_contig=True) + + @pytest.mark.parametrize("arr", [a, b, c]) + def test_order_k(self, arr): + res = arr.copy(order="K") + self.check_result( + res, + arr, + c_contig=arr.flags.c_contiguous, + f_contig=arr.flags.f_contiguous, + ) + + res = dpnp.copy(arr, order="K") + self.check_result( + res, + arr, + c_contig=arr.flags.c_contiguous, + f_contig=arr.flags.f_contiguous, + ) + + res = copy.copy(arr) + self.check_result( + res, + arr, + c_contig=arr.flags.c_contiguous, + f_contig=arr.flags.f_contiguous, + ) + + +@pytest.mark.parametrize( + "val", + [3.7, numpy.arange(7), [2, 7, 3.6], (-3, 4), range(4)], + ids=["scalar", "numpy.array", "list", "tuple", "range"], +) +def test_copy_not_dpnp_array(val): + a = dpnp.copy(val) + assert_allclose(a, val) diff --git a/tests/third_party/cupy/core_tests/test_elementwise.py b/tests/third_party/cupy/core_tests/test_elementwise.py new file mode 100644 index 000000000000..354a56fc5c42 --- /dev/null +++ b/tests/third_party/cupy/core_tests/test_elementwise.py @@ -0,0 +1,143 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.third_party.cupy import testing + + +class TestElementwise(unittest.TestCase): + def check_copy(self, dtype, src_id, dst_id): + with cuda.Device(src_id): + src = testing.shaped_arange((2, 3, 4), dtype=dtype) + with cuda.Device(dst_id): + dst = cupy.empty((2, 3, 4), dtype=dtype) + _core.elementwise_copy(src, dst) + testing.assert_allclose(src, dst) + + @pytest.mark.skip("`device` argument isn't supported") + @testing.for_all_dtypes() + def test_copy(self, dtype): + device_id = cuda.Device().id + self.check_copy(dtype, device_id, device_id) + + @pytest.mark.skip("`device` argument isn't supported") + @testing.for_all_dtypes() + def test_copy_multigpu_nopeer(self, dtype): + if cuda.runtime.deviceCanAccessPeer(0, 1) == 1: + pytest.skip("peer access is available") + with self.assertRaises(ValueError): + self.check_copy(dtype, 0, 1) + + @pytest.mark.skip("`device` argument isn't supported") + @testing.for_all_dtypes() + def test_copy_multigpu_peer(self, dtype): + if cuda.runtime.deviceCanAccessPeer(0, 1) != 1: + pytest.skip("peer access is unavailable") + with pytest.warns(cupy._util.PerformanceWarning): + self.check_copy(dtype, 0, 1) + + @testing.for_orders("CFAK") + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_copy_zero_sized_array1(self, xp, dtype, order): + src = xp.empty((0,), dtype=dtype) + res = xp.copy(src, order=order) + assert src is not res + return res + + @testing.for_orders("CFAK") + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose() + def test_copy_zero_sized_array2(self, xp, dtype, order): + src = xp.empty((1, 0, 2), dtype=dtype) + res = xp.copy(src, order=order) + assert src is not res + return res + + @testing.for_orders("CFAK") + def test_copy_orders(self, order): + a = cupy.empty((2, 3, 4)) + b = cupy.copy(a, order) + + a_cpu = numpy.empty((2, 3, 4)) + b_cpu = numpy.copy(a_cpu, order) + + assert b.strides == tuple(x / b_cpu.itemsize for x in b_cpu.strides) + + +@pytest.mark.skip("`ElementwiseKernel` function isn't supported") +class TestElementwiseInvalidShape(unittest.TestCase): + def test_invalid_shape(self): + with self.assertRaisesRegex(ValueError, "Out shape is mismatched"): + f = cupy.ElementwiseKernel("T x", "T y", "y += x") + x = cupy.arange(12).reshape(3, 4) + y = cupy.arange(4) + f(x, y) + + +@pytest.mark.skip("`ElementwiseKernel` function isn't supported") +class TestElementwiseInvalidArgument(unittest.TestCase): + def test_invalid_kernel_name(self): + with self.assertRaisesRegex(ValueError, "Invalid kernel name"): + cupy.ElementwiseKernel("T x", "", "", "1") + + +@pytest.mark.skip("`iinfo` function isn't supported") +class TestElementwiseType(unittest.TestCase): + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_upper_1(self, xp, dtype): + a = xp.array([0], dtype=xp.int8) + b = xp.iinfo(dtype).max + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_upper_2(self, xp, dtype): + a = xp.array([1], dtype=xp.int8) + b = xp.iinfo(dtype).max - 1 + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_upper_3(self, xp, dtype): + a = xp.array([xp.iinfo(dtype).max], dtype=dtype) + b = xp.int8(0) + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_upper_4(self, xp, dtype): + a = xp.array([xp.iinfo(dtype).max - 1], dtype=dtype) + b = xp.int8(1) + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_lower_1(self, xp, dtype): + a = xp.array([0], dtype=xp.int8) + b = xp.iinfo(dtype).min + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_lower_2(self, xp, dtype): + a = xp.array([-1], dtype=xp.int8) + b = xp.iinfo(dtype).min + 1 + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_lower_3(self, xp, dtype): + a = xp.array([xp.iinfo(dtype).min], dtype=dtype) + b = xp.int8(0) + return a + b + + @testing.for_int_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_large_int_lower_4(self, xp, dtype): + a = xp.array([xp.iinfo(dtype).min + 1], dtype=dtype) + b = xp.int8(-1) + return a + b diff --git a/tests/third_party/cupy/creation_tests/test_from_data.py b/tests/third_party/cupy/creation_tests/test_from_data.py index be606d1840f3..ae29f393e9dc 100644 --- a/tests/third_party/cupy/creation_tests/test_from_data.py +++ b/tests/third_party/cupy/creation_tests/test_from_data.py @@ -501,7 +501,6 @@ def test_asarray_from_big_endian(self, xp, dtype): # happens to work before the change in #5828 return b + b - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_CF_orders() @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() From 23f24accab0a72b1f8b2251087963f93aea94935 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Sun, 27 Aug 2023 22:59:35 +0200 Subject: [PATCH 2/3] Renamed test files, since coverage tool requires unique names --- dpnp/dpnp_iface_arraycreation.py | 2 +- .../{test_elementwise.py => test_core_elementwise.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/third_party/cupy/core_tests/{test_elementwise.py => test_core_elementwise.py} (100%) diff --git a/dpnp/dpnp_iface_arraycreation.py b/dpnp/dpnp_iface_arraycreation.py index ab5c6d1ce32a..22e7c1e0d9d1 100644 --- a/dpnp/dpnp_iface_arraycreation.py +++ b/dpnp/dpnp_iface_arraycreation.py @@ -322,7 +322,7 @@ def asarray( Returns ------- out : dpnp.ndarray - Array interpretation of `a`. No copy is performed if the input + Array interpretation of `a`. No copy is performed if the input is already an ndarray with matching dtype and order. Limitations diff --git a/tests/third_party/cupy/core_tests/test_elementwise.py b/tests/third_party/cupy/core_tests/test_core_elementwise.py similarity index 100% rename from tests/third_party/cupy/core_tests/test_elementwise.py rename to tests/third_party/cupy/core_tests/test_core_elementwise.py From 778d3dba5627e13c396f3dbe7dc61caa030e7958 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Mon, 28 Aug 2023 11:34:33 +0200 Subject: [PATCH 3/3] dpctl accepted lowercase order --- dpnp/backend/kernels/dpnp_krnl_elemwise.cpp | 2 -- dpnp/backend/kernels/dpnp_krnl_logic.cpp | 2 -- dpnp/dpnp_container.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp index 6fc494b43849..39026554d4b8 100644 --- a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp @@ -1326,8 +1326,6 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap) { \ constexpr size_t lws = 64; \ constexpr unsigned int vec_sz = 8; \ - constexpr sycl::access::address_space global_space = \ - sycl::access::address_space::global_space; \ \ auto gws_range = sycl::range<1>( \ ((result_size + lws * vec_sz - 1) / (lws * vec_sz)) * \ diff --git a/dpnp/backend/kernels/dpnp_krnl_logic.cpp b/dpnp/backend/kernels/dpnp_krnl_logic.cpp index ac8f7ca4560b..818feeb43c85 100644 --- a/dpnp/backend/kernels/dpnp_krnl_logic.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_logic.cpp @@ -521,8 +521,6 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef, else { \ constexpr size_t lws = 64; \ constexpr unsigned int vec_sz = 8; \ - constexpr sycl::access::address_space global_space = \ - sycl::access::address_space::global_space; \ \ auto gws_range = sycl::range<1>( \ ((result_size + lws * vec_sz - 1) / (lws * vec_sz)) * lws); \ diff --git a/dpnp/dpnp_container.py b/dpnp/dpnp_container.py index f7e8cc355d82..faf14c3e97bb 100644 --- a/dpnp/dpnp_container.py +++ b/dpnp/dpnp_container.py @@ -140,8 +140,6 @@ def copy(x1, /, *, order="K"): """Creates `dpnp_array` as a copy of given instance of input array.""" if order is None: order = "K" - else: - order = order.upper() array_obj = dpt.copy(dpnp.get_usm_ndarray(x1), order=order) return dpnp_array(array_obj.shape, buffer=array_obj, order="K")