@@ -31,46 +31,37 @@ def __init__(
31
31
aligned = True ,
32
32
addrspace = address_space .GLOBAL ,
33
33
):
34
+ if not isinstance (device , str ):
35
+ raise TypeError (
36
+ "The device keyword arg should be a str object specifying "
37
+ "a SYCL filter selector"
38
+ )
39
+
40
+ if not isinstance (queue , dpctl .SyclQueue ) and queue is not None :
41
+ raise TypeError (
42
+ "The queue keyword arg should be a dpctl.SyclQueue object or None"
43
+ )
44
+
34
45
self .usm_type = usm_type
35
46
self .addrspace = addrspace
36
47
37
- if queue is not None and device != "unknown" :
38
- if not isinstance (device , str ):
39
- raise TypeError (
40
- "The device keyword arg should be a str object specifying "
41
- "a SYCL filter selector"
42
- )
43
- if not isinstance (queue , dpctl .SyclQueue ):
44
- raise TypeError (
45
- "The queue keyword arg should be a dpctl.SyclQueue object"
46
- )
47
- d1 = queue .sycl_device
48
- d2 = dpctl .SyclDevice (device )
49
- if d1 != d2 :
50
- raise TypeError (
51
- "The queue keyword arg and the device keyword arg specify "
52
- "different SYCL devices"
53
- )
54
- self .queue = queue
55
- self .device = device
56
- elif queue is None and device != "unknown" :
57
- if not isinstance (device , str ):
58
- raise TypeError (
59
- "The device keyword arg should be a str object specifying "
60
- "a SYCL filter selector"
61
- )
62
- self .queue = dpctl .SyclQueue (device )
63
- self .device = self .queue .sycl_device .filter_string
64
- elif queue is not None and device == "unknown" :
65
- if not isinstance (queue , dpctl .SyclQueue ):
66
- raise TypeError (
67
- "The queue keyword arg should be a dpctl.SyclQueue object"
68
- )
69
- self .device = self .queue .sycl_device .filter_string
48
+ if device == "unknown" :
49
+ device = None
50
+
51
+ if queue is not None and device is not None :
52
+ raise TypeError (
53
+ "'queue' and 'device' keywords can not be both specified"
54
+ )
55
+
56
+ if queue is not None :
70
57
self .queue = queue
71
58
else :
72
- self .queue = dpctl .SyclQueue ()
73
- self .device = self .queue .sycl_device .filter_string
59
+ if device is None :
60
+ device = dpctl .SyclDevice ()
61
+
62
+ self .queue = dpctl .get_device_cached_queue (device )
63
+
64
+ self .device = self .queue .sycl_device .filter_string
74
65
75
66
if not dtype :
76
67
dummy_tensor = dpctl .tensor .empty (
0 commit comments