diff --git a/dpctl/apis/include/dpctl4pybind11.hpp b/dpctl/apis/include/dpctl4pybind11.hpp index 7a29754fcc..1434da1f32 100644 --- a/dpctl/apis/include/dpctl4pybind11.hpp +++ b/dpctl/apis/include/dpctl4pybind11.hpp @@ -965,6 +965,15 @@ class usm_ndarray : public py::object return *(reinterpret_cast(QRef)); } + sycl::device get_device() const + { + PyUSMArrayObject *raw_ar = usm_array_ptr(); + + auto const &api = ::dpctl::detail::dpctl_capi::get(); + DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar); + return reinterpret_cast(QRef)->get_device(); + } + int get_typenum() const { PyUSMArrayObject *raw_ar = usm_array_ptr();