From e28bd881e23778b630dbb97d1a6fa4713af10556 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 11 Aug 2023 08:19:24 -0500 Subject: [PATCH 1/3] Add default_device_index_type(queue_or_dev) utility This returns default index type for give device. Since all devices are 64-bit devices, it always returns "i8". --- .../libtensor/source/device_support_queries.cpp | 11 +++++++++++ .../libtensor/source/device_support_queries.hpp | 1 + dpctl/tensor/libtensor/source/tensor_py.cpp | 6 +++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/source/device_support_queries.cpp b/dpctl/tensor/libtensor/source/device_support_queries.cpp index 16ae43ba97..946c36ad26 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.cpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.cpp @@ -71,6 +71,11 @@ std::string _default_device_bool_type(sycl::device) return "b1"; } +std::string _default_device_index_type(sycl::device) +{ + return "i8"; +} + sycl::device _extract_device(py::object arg) { auto const &api = dpctl::detail::dpctl_capi::get(); @@ -115,6 +120,12 @@ std::string default_device_complex_type(py::object arg) return _default_device_complex_type(d); } +std::string default_device_index_type(py::object arg) +{ + sycl::device d = _extract_device(arg); + return _default_device_index_type(d); +} + } // namespace py_internal } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/device_support_queries.hpp b/dpctl/tensor/libtensor/source/device_support_queries.hpp index a54835fc75..6626b3502a 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.hpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.hpp @@ -41,6 +41,7 @@ extern std::string default_device_fp_type(py::object); extern std::string default_device_int_type(py::object); extern std::string default_device_bool_type(py::object); extern std::string default_device_complex_type(py::object); +extern std::string default_device_index_type(py::object); } // namespace py_internal } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 1833c2d770..9b4ba6cdad 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -297,9 +297,13 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("default_device_complex_type", dpctl::tensor::py_internal::default_device_complex_type, - "Gives default complex floating point type support by device.", + "Gives default complex floating point type supported by device.", py::arg("dev")); + m.def("default_device_index_type", + dpctl::tensor::py_internal::default_device_index_type, + "Gives default index type supported by device.", py::arg("dev")); + auto tril_fn = [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, py::ssize_t k, sycl::queue exec_q, From bf22a284dadcba363dc92e18adaae4bb742bd73b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 11 Aug 2023 08:20:34 -0500 Subject: [PATCH 2/3] dpt.nonzero should use default index type Closes gh-1335. The issue was caused by nonzero using default integral data type, not default index data type. --- dpctl/tensor/_copy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 565a11dec7..99d39de536 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -586,7 +586,7 @@ def _nonzero_impl(ary): mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C" ) mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q) - indexes_dt = ti.default_device_int_type(exec_q.sycl_device) + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) indexes = dpt.empty( (ary.ndim, mask_count), dtype=indexes_dt, From e70891ba3c836a466cd1b7155061ffacd3f867f5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 11 Aug 2023 08:21:47 -0500 Subject: [PATCH 3/3] Adjusted test_nonzero_dtype to use default index type as reference --- dpctl/tests/test_usm_ndarray_indexing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 87d89a1b8d..9d166226e7 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -22,6 +22,7 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti from dpctl.utils import ExecutionPlacementError _all_dtypes = [ @@ -1353,7 +1354,7 @@ def test_nonzero_dtype(): x = dpt.ones((3, 4)) idx, idy = dpt.nonzero(x) # create array using device's - # default integral data type - ref = dpt.arange(8) - assert idx.dtype == ref.dtype - assert idy.dtype == ref.dtype + # default index data type + index_dt = dpt.dtype(ti.default_device_index_type(x.sycl_queue)) + assert idx.dtype == index_dt + assert idy.dtype == index_dt