diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 565a11dec7..99d39de536 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -586,7 +586,7 @@ def _nonzero_impl(ary): mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C" ) mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q) - indexes_dt = ti.default_device_int_type(exec_q.sycl_device) + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) indexes = dpt.empty( (ary.ndim, mask_count), dtype=indexes_dt, diff --git a/dpctl/tensor/libtensor/source/device_support_queries.cpp b/dpctl/tensor/libtensor/source/device_support_queries.cpp index 16ae43ba97..946c36ad26 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.cpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.cpp @@ -71,6 +71,11 @@ std::string _default_device_bool_type(sycl::device) return "b1"; } +std::string _default_device_index_type(sycl::device) +{ + return "i8"; +} + sycl::device _extract_device(py::object arg) { auto const &api = dpctl::detail::dpctl_capi::get(); @@ -115,6 +120,12 @@ std::string default_device_complex_type(py::object arg) return _default_device_complex_type(d); } +std::string default_device_index_type(py::object arg) +{ + sycl::device d = _extract_device(arg); + return _default_device_index_type(d); +} + } // namespace py_internal } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/device_support_queries.hpp b/dpctl/tensor/libtensor/source/device_support_queries.hpp index a54835fc75..6626b3502a 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.hpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.hpp @@ -41,6 +41,7 @@ extern std::string default_device_fp_type(py::object); extern std::string default_device_int_type(py::object); extern std::string default_device_bool_type(py::object); extern std::string default_device_complex_type(py::object); +extern std::string default_device_index_type(py::object); } // namespace py_internal } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 1833c2d770..9b4ba6cdad 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -297,9 +297,13 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("default_device_complex_type", dpctl::tensor::py_internal::default_device_complex_type, - "Gives default complex floating point type support by device.", + "Gives default complex floating point type supported by device.", py::arg("dev")); + m.def("default_device_index_type", + dpctl::tensor::py_internal::default_device_index_type, + "Gives default index type supported by device.", py::arg("dev")); + auto tril_fn = [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, py::ssize_t k, sycl::queue exec_q, diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 87d89a1b8d..9d166226e7 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -22,6 +22,7 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti from dpctl.utils import ExecutionPlacementError _all_dtypes = [ @@ -1353,7 +1354,7 @@ def test_nonzero_dtype(): x = dpt.ones((3, 4)) idx, idy = dpt.nonzero(x) # create array using device's - # default integral data type - ref = dpt.arange(8) - assert idx.dtype == ref.dtype - assert idy.dtype == ref.dtype + # default index data type + index_dt = dpt.dtype(ti.default_device_index_type(x.sycl_queue)) + assert idx.dtype == index_dt + assert idy.dtype == index_dt