23
23
default_device_int_type ,
24
24
)
25
25
26
- __array_api_version__ = "2022.12"
27
-
28
26
29
27
def _isdtype_impl (dtype , kind ):
30
- if isinstance (kind , dpt .dtype ):
31
- return dtype == kind
32
-
33
- elif isinstance (kind , str ):
28
+ if isinstance (kind , str ):
34
29
if kind == "bool" :
35
30
return dtype .kind == "b"
36
31
elif kind == "signed integer" :
@@ -54,15 +49,21 @@ def _isdtype_impl(dtype, kind):
54
49
raise TypeError (f"Unsupported data type kind: { kind } " )
55
50
56
51
57
- class __array_namespace_info__ :
52
+ __array_api_version__ = "2022.12"
53
+
54
+
55
+ class Info :
56
+ """
57
+ namespace returned by `__array_namespace_info__()`
58
+ """
59
+
58
60
def __init__ (self ):
59
61
self ._capabilities = {
60
62
"boolean_indexing" : True ,
61
63
"data_dependent_shapes" : True ,
62
64
}
63
65
self ._all_dtypes = {
64
66
"bool" : dpt .bool ,
65
- "float16" : dpt .float16 ,
66
67
"float32" : dpt .float32 ,
67
68
"float64" : dpt .float64 ,
68
69
"complex64" : dpt .complex64 ,
@@ -78,41 +79,129 @@ def __init__(self):
78
79
}
79
80
80
81
def capabilities (self ):
82
+ """
83
+ Returns a dictionary of `dpctl`'s capabilities.
84
+
85
+ Returns:
86
+ dict:
87
+ dictionary of `dpctl`'s capabilities
88
+ - `boolean_indexing`: bool
89
+ - `data_dependent_shapes`: bool
90
+ """
81
91
return self ._capabilities .copy ()
82
92
83
93
def default_device (self ):
94
+ """
95
+ Returns the default SYCL device.
96
+ """
84
97
return dpctl .select_default_device ()
85
98
86
99
def default_dtypes (self , device = None ):
100
+ """
101
+ Returns a dictionary of default data types for `device`.
102
+
103
+ Args:
104
+ device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
105
+ dpctl.tensor.Device]):
106
+ array API concept of device used in getting default data types.
107
+ `device` can be `None` (in which case the default device is
108
+ used), an instance of :class:`dpctl.SyclDevice` corresponding
109
+ to a non-partitioned SYCL device, an instance of
110
+ :class:`dpctl.SyclQueue`, or a `Device` object returned by
111
+ :attr:`dpctl.tensor.usm_array.device`. Default: `None`.
112
+
113
+ Returns:
114
+ dict:
115
+ a dictionary of default data types for `device`
116
+ - `real floating`: dtype
117
+ - `complex floating`: dtype
118
+ - `integral`: dtype
119
+ - `indexing`: dtype
120
+ """
87
121
if device is None :
88
122
device = dpctl .select_default_device ()
123
+ elif isinstance (device , dpt .Device ):
124
+ device = device .sycl_device
89
125
return {
90
- "real floating" : default_device_fp_type (device ),
91
- "complex floating" : default_device_complex_type ,
92
- "integral" : default_device_int_type (device ),
93
- "indexing" : default_device_index_type (device ),
126
+ "real floating" : dpt . dtype ( default_device_fp_type (device ) ),
127
+ "complex floating" : dpt . dtype ( default_device_complex_type ( device )) ,
128
+ "integral" : dpt . dtype ( default_device_int_type (device ) ),
129
+ "indexing" : dpt . dtype ( default_device_index_type (device ) ),
94
130
}
95
131
96
132
def dtypes (self , device = None , kind = None ):
133
+ """
134
+ Returns a dictionary of all Array API data types of a specified `kind`
135
+ supported by `device`
136
+
137
+ This dictionary only includes data types supported by the array API.
138
+
139
+ See [array API](array_api).
140
+
141
+ [array_api]: https://data-apis.org/array-api/latest/
142
+
143
+ Args:
144
+ device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
145
+ dpctl.tensor.Device, str]):
146
+ array API concept of device used in getting default data types.
147
+ `device` can be `None` (in which case the default device is
148
+ used), an instance of :class:`dpctl.SyclDevice` corresponding
149
+ to a non-partitioned SYCL device, an instance of
150
+ :class:`dpctl.SyclQueue`, or a `Device` object returned by
151
+ :attr:`dpctl.tensor.usm_array.device`. Default: `None`.
152
+
153
+ kind (Optional[str, Tuple[str, ...]]):
154
+ data type kind.
155
+ - if `kind` is `None`, returns a dictionary of all data types
156
+ supported by `device`
157
+ - if `kind` is a string, returns a dictionary containing the
158
+ data types belonging to the data type kind specified.
159
+ Supports:
160
+ - "bool"
161
+ - "signed integer"
162
+ - "unsigned integer"
163
+ - "integral"
164
+ - "real floating"
165
+ - "complex floating"
166
+ - "numeric"
167
+ - if `kind` is a tuple, the tuple represents a union of `kind`
168
+ strings, and returns a dictionary containing data types
169
+ corresponding to the-specified union.
170
+ Default: `None`.
171
+
172
+ Returns:
173
+ dict:
174
+ a dictionary of the supported data types of the specified `kind`
175
+ """
97
176
if device is None :
98
177
device = dpctl .select_default_device ()
99
- ignored_types = []
100
- if not device .has_aspect_fp16 :
101
- ignored_types .append ("float16" )
102
- if not device .has_aspect_fp64 :
103
- ignored_types .append ("float64" )
178
+ elif isinstance (device , dpt .Device ):
179
+ device = device .sycl_device
180
+ _fp64 = device .has_aspect_fp64
104
181
if kind is None :
105
182
return {
106
183
key : val
107
184
for key , val in self ._all_dtypes .items ()
108
- if key not in ignored_types
185
+ if ( key != "float64" or _fp64 )
109
186
}
110
187
else :
111
188
return {
112
189
key : val
113
190
for key , val in self ._all_dtypes .items ()
114
- if key not in ignored_types and _isdtype_impl (val , kind )
191
+ if ( key != "float64" or _fp64 ) and _isdtype_impl (val , kind )
115
192
}
116
193
117
194
def devices (self ):
195
+ """
196
+ Returns a list of supported devices.
197
+ """
118
198
return dpctl .get_devices ()
199
+
200
+
201
+ def __array_namespace_info__ ():
202
+ """__array_namespace_info__()
203
+
204
+ Returns a namespace with Array API namespace inspection utilities.
205
+
206
+ """
207
+ return Info ()
0 commit comments