@@ -303,13 +303,22 @@ cdef class _DeviceDefaultQueueCache:
303
303
self .__device_queue_map__ = dict ()
304
304
305
305
def get_or_create (self , key ):
306
- """ Return instance of SyclQueue and indicator if cache has been modified"""
307
- if isinstance (key, tuple ) and len (key) == 2 and isinstance (key[0 ], SyclContext) and isinstance (key[1 ], SyclDevice):
306
+ """ Return instance of SyclQueue and indicator if cache
307
+ has been modified"""
308
+ if (
309
+ isinstance (key, tuple )
310
+ and len (key) == 2
311
+ and isinstance (key[0 ], SyclContext)
312
+ and isinstance (key[1 ], SyclDevice)
313
+ ):
308
314
ctx_dev = key
309
315
q = None
310
316
elif isinstance (key, SyclDevice):
311
317
q = SyclQueue(key)
312
318
ctx_dev = q.sycl_context, key
319
+ elif isinstance (key, str ):
320
+ q = SyclQueue(key)
321
+ ctx_dev = q.sycl_context, q.sycl_device
313
322
else :
314
323
raise TypeError
315
324
if ctx_dev in self .__device_queue_map__:
@@ -322,12 +331,16 @@ cdef class _DeviceDefaultQueueCache:
322
331
self .__device_queue_map__.update(dev_queue_map)
323
332
324
333
def __copy__ (self ):
325
- cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__ (_DeviceDefaultQueueCache)
334
+ cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__ (
335
+ _DeviceDefaultQueueCache)
326
336
_copy._update_map(self .__device_queue_map__)
327
337
return _copy
328
338
329
339
330
- _global_device_queue_cache = ContextVar(' global_device_queue_cache' , default = _DeviceDefaultQueueCache())
340
+ _global_device_queue_cache = ContextVar(
341
+ ' global_device_queue_cache' ,
342
+ default = _DeviceDefaultQueueCache()
343
+ )
331
344
332
345
333
346
cpdef object get_device_cached_queue(object key):
0 commit comments