diff --git a/dpctl/_sycl_queue_manager.pxd b/dpctl/_sycl_queue_manager.pxd
index 9f3d0b6799..8e8bec95d2 100644
--- a/dpctl/_sycl_queue_manager.pxd
+++ b/dpctl/_sycl_queue_manager.pxd
@@ -17,9 +17,12 @@
 # distutils: language = c++
 # cython: language_level=3
 
+from ._sycl_device cimport SyclDevice
 from ._sycl_queue cimport SyclQueue
 
 
 cpdef SyclQueue get_current_queue()
 cpdef get_current_device_type ()
 cpdef get_current_backend()
+
+cpdef object get_device_cached_queue(object)
diff --git a/dpctl/_sycl_queue_manager.pyx b/dpctl/_sycl_queue_manager.pyx
index 65887c387a..9f92f52092 100644
--- a/dpctl/_sycl_queue_manager.pyx
+++ b/dpctl/_sycl_queue_manager.pyx
@@ -20,6 +20,7 @@
 
 import logging
 from contextlib import ExitStack, contextmanager
+from contextvars import ContextVar
 
 from .enum_types import backend_type, device_type
 
@@ -35,6 +36,7 @@ from ._backend cimport ( # noqa: E211
     _device_type,
 )
 from ._sycl_context cimport SyclContext
+from ._sycl_device cimport SyclDevice
 
 __all__ = [
     "device_context",
@@ -44,6 +46,7 @@ __all__ = [
     "get_num_activated_queues",
     "is_in_device_context",
     "set_global_queue",
+    "_global_device_queue_cache",
 ]
 
 _logger = logging.getLogger(__name__)
@@ -291,3 +294,84 @@ def device_context(arg):
             _mgr._remove_current_queue()
         else:
             _logger.debug("No queue was created so nothing to do")
+
+
+# Maps (SyclContext, SyclDevice) pairs to the SyclQueue created for them,
+# so that repeated requests for the same pair return the same queue.
+cdef class _DeviceDefaultQueueCache:
+    cdef dict __device_queue_map__
+
+    def __cinit__(self):
+        self.__device_queue_map__ = dict()
+
+    def get_or_create(self, key):
+        """Return instance of SyclQueue and indicator if cache
+        has been modified.
+
+        Args:
+            key: a SyclDevice, or a 2-tuple (SyclContext, SyclDevice).
+
+        Returns:
+            A pair (queue, changed); changed is True only when this
+            call added a new entry to the cache.
+
+        Raises:
+            TypeError: if key has neither of the supported forms.
+        """
+        if (
+            isinstance(key, tuple)
+            and len(key) == 2
+            and isinstance(key[0], SyclContext)
+            and isinstance(key[1], SyclDevice)
+        ):
+            ctx_dev = key
+            q = None
+        elif isinstance(key, SyclDevice):
+            # A queue is built first to learn which context goes with
+            # the device; it is stored only if the lookup below misses.
+            q = SyclQueue(key)
+            ctx_dev = q.sycl_context, key
+        else:
+            raise TypeError(
+                "Expected a SyclDevice or a (SyclContext, SyclDevice) "
+                f"tuple, got {type(key)}"
+            )
+        if ctx_dev in self.__device_queue_map__:
+            return self.__device_queue_map__[ctx_dev], False
+        if q is None:
+            q = SyclQueue(*ctx_dev)
+        self.__device_queue_map__[ctx_dev] = q
+        return q, True
+
+    cdef _update_map(self, dev_queue_map):
+        self.__device_queue_map__.update(dev_queue_map)
+
+    def __copy__(self):
+        # Shallow copy: the new cache gets its own dict, same queues.
+        cdef _DeviceDefaultQueueCache _copy = (
+            _DeviceDefaultQueueCache.__new__(_DeviceDefaultQueueCache)
+        )
+        _copy._update_map(self.__device_queue_map__)
+        return _copy
+
+
+_global_device_queue_cache = ContextVar(
+    'global_device_queue_cache',
+    default=_DeviceDefaultQueueCache()
+)
+
+
+cpdef object get_device_cached_queue(object key):
+    """Get cached queue associated with given device.
+
+    key is a SyclDevice or a (SyclContext, SyclDevice) tuple; it is
+    resolved against the cache currently held by the
+    _global_device_queue_cache context variable.
+    """
+    _cache = _global_device_queue_cache.get()
+    q_, changed_ = _cache.get_or_create(key)
+    if changed_:
+        # _cache was updated in place; store it as the variable's value
+        # for the current context.
+        _global_device_queue_cache.set(_cache)
+    return q_
diff --git a/dpctl/memory/_memory.pyx b/dpctl/memory/_memory.pyx
index b650dca33c..044a5b55a1 100644
--- a/dpctl/memory/_memory.pyx
+++ b/dpctl/memory/_memory.pyx
@@ -61,6 +61,7 @@ from dpctl._backend cimport ( # noqa: E211
 from .._sycl_context cimport SyclContext
 from .._sycl_device cimport SyclDevice
 from .._sycl_queue cimport SyclQueue
+from .._sycl_queue_manager cimport get_device_cached_queue
 
 import collections
 import numbers
@@ -150,7 +151,7 @@ cdef class _Memory:
 
         if (nbytes > 0):
             if queue is None:
-                queue = dpctl.SyclQueue()
+                queue = get_device_cached_queue(dpctl.SyclDevice())
 
             QRef = queue.get_queue_ref()
             if (ptr_type == b"shared"):
diff --git a/dpctl/tensor/_device.py b/dpctl/tensor/_device.py
index 96185e507d..07a3e41e09 100644
--- a/dpctl/tensor/_device.py
+++ b/dpctl/tensor/_device.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import dpctl
+from dpctl._sycl_queue_manager import get_device_cached_queue
 
 __doc__ = "Implementation of array API mandated Device class"
 
@@ -60,9 +61,7 @@ def create_device(cls, dev):
         elif isinstance(dev, dpctl.SyclDevice):
             par = dev.parent_device
             if par is None:
-                if dev not in cls.__device_queue_map__:
-                    cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
-                obj.sycl_queue_ = cls.__device_queue_map__[dev]
+                obj.sycl_queue_ = get_device_cached_queue(dev)
             else:
                 raise ValueError(
                     f"Using non-root device {dev} to specify offloading "
@@ -74,9 +73,7 @@ def create_device(cls, dev):
                 _dev = dpctl.SyclDevice()
             else:
                 _dev = dpctl.SyclDevice(dev)
-            if _dev not in cls.__device_queue_map__:
-                cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
-            obj.sycl_queue_ = cls.__device_queue_map__[_dev]
+            obj.sycl_queue_ = get_device_cached_queue(_dev)
         return obj
 
     @property
diff --git a/dpctl/tensor/_dlpack.pyx b/dpctl/tensor/_dlpack.pyx
index c93bfb1c1c..1d4d72df2d 100644
--- a/dpctl/tensor/_dlpack.pyx
+++ b/dpctl/tensor/_dlpack.pyx
@@ -24,6 +24,7 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t
 
 cimport dpctl as c_dpctl
 cimport dpctl.memory as c_dpmem
+from dpctl._sycl_queue_manager cimport get_device_cached_queue
 
 from .._backend cimport (
     DPCTLDevice_Delete,
@@ -344,12 +345,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
             if _IS_LINUX:
                 default_context = root_device.sycl_platform.default_context
             else:
-                default_context = dpctl.SyclQueue(root_device).sycl_context
+                default_context = get_device_cached_queue(root_device).sycl_context
         except RuntimeError:
-            default_context = dpctl.SyclQueue(root_device).sycl_context
+            default_context = get_device_cached_queue(root_device).sycl_context
         if dlm_tensor.dl_tensor.data is NULL:
             usm_type = b"device"
-            q = dpctl.SyclQueue(default_context, root_device)
+            q = get_device_cached_queue((default_context, root_device,))
         else:
             usm_type = c_dpmem._Memory.get_pointer_type(
                 dlm_tensor.dl_tensor.data,
@@ -364,7 +365,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
                 dlm_tensor.dl_tensor.data,
                 default_context
             )
-            q = dpctl.SyclQueue(default_context, alloc_device)
+            q = get_device_cached_queue((default_context, alloc_device,))
         if dlm_tensor.dl_tensor.dtype.bits % 8:
             raise BufferError(
                 "Can not import DLPack tensor whose element's "
diff --git a/dpctl/tests/test_sycl_queue_manager.py b/dpctl/tests/test_sycl_queue_manager.py
index c39e2d6b30..d694650b43 100644
--- a/dpctl/tests/test_sycl_queue_manager.py
+++ b/dpctl/tests/test_sycl_queue_manager.py
@@ -226,3 +226,22 @@ def test_nested_context_factory_exception_if_wrong_factory(
         with _register_nested_context_factory(factory):
             with dpctl.device_context("opencl:cpu:0"):
                 pass
+
+
+def test__DeviceDefaultQueueCache():
+    import copy
+
+    from dpctl._sycl_queue_manager import _global_device_queue_cache as cache
+    from dpctl._sycl_queue_manager import get_device_cached_queue
+
+    try:
+        d = dpctl.SyclDevice()
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("Could not create default device")
+
+    q1 = get_device_cached_queue(d)
+    cache_copy = copy.copy(cache.get())
+    q2, changed = cache_copy.get_or_create(d)
+
+    assert not changed
+    assert q1 == q2
diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py
index baf870c036..e442cff43f 100644
--- a/dpctl/tests/test_usm_ndarray_dlpack.py
+++ b/dpctl/tests/test_usm_ndarray_dlpack.py
@@ -81,6 +81,28 @@ def test_dlpack_exporter(typestr, usm_type):
     assert caps_fn(caps2, b"dltensor")
 
 
+def test_dlpack_exporter_empty(typestr, usm_type):
+    caps_fn = ctypes.pythonapi.PyCapsule_IsValid
+    caps_fn.restype = bool
+    caps_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
+    sycl_dev = dpctl.select_default_device()
+    skip_if_dtype_not_supported(typestr, sycl_dev)
+    X = dpt.empty((0,), dtype=typestr, usm_type=usm_type, device=sycl_dev)
+    caps = X.__dlpack__()
+    assert caps_fn(caps, b"dltensor")
+    Y = dpt.empty(
+        (
+            1,
+            0,
+        ),
+        dtype=typestr,
+        usm_type=usm_type,
+        device=sycl_dev,
+    )
+    caps = Y.__dlpack__()
+    assert caps_fn(caps, b"dltensor")
+
+
 def test_dlpack_exporter_stream():
     try:
         q1 = dpctl.SyclQueue()