Skip to content

Commit 68875dc

Browse files
committed
Adds __array_namespace_info__ docstrings
Disallows dtypes for `kind` kwarg in __array_namespace_info__().dtypes Removes `float16` from dtypes listed by __array_namespace_info__ as per spec Permits dpctl.tensor.Device objects in device keyword arguments in array API inspection utilities
1 parent 91b4aaf commit 68875dc

File tree

1 file changed

+108
-19
lines changed

1 file changed

+108
-19
lines changed

dpctl/tensor/_array_api.py

+108-19
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,9 @@
2323
default_device_int_type,
2424
)
2525

26-
__array_api_version__ = "2022.12"
27-
2826

2927
def _isdtype_impl(dtype, kind):
30-
if isinstance(kind, dpt.dtype):
31-
return dtype == kind
32-
33-
elif isinstance(kind, str):
28+
if isinstance(kind, str):
3429
if kind == "bool":
3530
return dtype.kind == "b"
3631
elif kind == "signed integer":
@@ -54,15 +49,21 @@ def _isdtype_impl(dtype, kind):
5449
raise TypeError(f"Unsupported data type kind: {kind}")
5550

5651

57-
class __array_namespace_info__:
52+
__array_api_version__ = "2022.12"
53+
54+
55+
class Info:
56+
"""
57+
namespace returned by `__array_namespace_info__()`
58+
"""
59+
5860
def __init__(self):
5961
self._capabilities = {
6062
"boolean_indexing": True,
6163
"data_dependent_shapes": True,
6264
}
6365
self._all_dtypes = {
6466
"bool": dpt.bool,
65-
"float16": dpt.float16,
6667
"float32": dpt.float32,
6768
"float64": dpt.float64,
6869
"complex64": dpt.complex64,
@@ -78,41 +79,129 @@ def __init__(self):
7879
}
7980

8081
def capabilities(self):
82+
"""
83+
Returns a dictionary of `dpctl`'s capabilities.
84+
85+
Returns:
86+
dict:
87+
dictionary of `dpctl`'s capabilities
88+
- `boolean_indexing`: bool
89+
- `data_dependent_shapes`: bool
90+
"""
8191
return self._capabilities.copy()
8292

8393
def default_device(self):
94+
"""
95+
Returns the default SYCL device.
96+
"""
8497
return dpctl.select_default_device()
8598

8699
def default_dtypes(self, device=None):
100+
"""
101+
Returns a dictionary of default data types for `device`.
102+
103+
Args:
104+
device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
105+
dpctl.tensor.Device]):
106+
array API concept of device used in getting default data types.
107+
`device` can be `None` (in which case the default device is
108+
used), an instance of :class:`dpctl.SyclDevice` corresponding
109+
to a non-partitioned SYCL device, an instance of
110+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
111+
:attr:`dpctl.tensor.usm_array.device`. Default: `None`.
112+
113+
Returns:
114+
dict:
115+
a dictionary of default data types for `device`
116+
- `real floating`: dtype
117+
- `complex floating`: dtype
118+
- `integral`: dtype
119+
- `indexing`: dtype
120+
"""
87121
if device is None:
88122
device = dpctl.select_default_device()
123+
elif isinstance(device, dpt.Device):
124+
device = device.sycl_device
89125
return {
90-
"real floating": default_device_fp_type(device),
91-
"complex floating": default_device_complex_type,
92-
"integral": default_device_int_type(device),
93-
"indexing": default_device_index_type(device),
126+
"real floating": dpt.dtype(default_device_fp_type(device)),
127+
"complex floating": dpt.dtype(default_device_complex_type(device)),
128+
"integral": dpt.dtype(default_device_int_type(device)),
129+
"indexing": dpt.dtype(default_device_index_type(device)),
94130
}
95131

96132
def dtypes(self, device=None, kind=None):
133+
"""
134+
Returns a dictionary of all Array API data types of a specified `kind`
135+
supported by `device`
136+
137+
This dictionary only includes data types supported by the array API.
138+
139+
See [array API](array_api).
140+
141+
[array_api]: https://data-apis.org/array-api/latest/
142+
143+
Args:
144+
device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
145+
dpctl.tensor.Device, str]):
146+
array API concept of device used in getting default data types.
147+
`device` can be `None` (in which case the default device is
148+
used), an instance of :class:`dpctl.SyclDevice` corresponding
149+
to a non-partitioned SYCL device, an instance of
150+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
151+
:attr:`dpctl.tensor.usm_array.device`. Default: `None`.
152+
153+
kind (Optional[str, Tuple[str, ...]]):
154+
data type kind.
155+
- if `kind` is `None`, returns a dictionary of all data types
156+
supported by `device`
157+
- if `kind` is a string, returns a dictionary containing the
158+
data types belonging to the data type kind specified.
159+
Supports:
160+
- "bool"
161+
- "signed integer"
162+
- "unsigned integer"
163+
- "integral"
164+
- "real floating"
165+
- "complex floating"
166+
- "numeric"
167+
- if `kind` is a tuple, the tuple represents a union of `kind`
168+
strings, and returns a dictionary containing data types
169+
corresponding to the-specified union.
170+
Default: `None`.
171+
172+
Returns:
173+
dict:
174+
a dictionary of the supported data types of the specified `kind`
175+
"""
97176
if device is None:
98177
device = dpctl.select_default_device()
99-
ignored_types = []
100-
if not device.has_aspect_fp16:
101-
ignored_types.append("float16")
102-
if not device.has_aspect_fp64:
103-
ignored_types.append("float64")
178+
elif isinstance(device, dpt.Device):
179+
device = device.sycl_device
180+
_fp64 = device.has_aspect_fp64
104181
if kind is None:
105182
return {
106183
key: val
107184
for key, val in self._all_dtypes.items()
108-
if key not in ignored_types
185+
if (key != "float64" or _fp64)
109186
}
110187
else:
111188
return {
112189
key: val
113190
for key, val in self._all_dtypes.items()
114-
if key not in ignored_types and _isdtype_impl(val, kind)
191+
if (key != "float64" or _fp64) and _isdtype_impl(val, kind)
115192
}
116193

117194
def devices(self):
195+
"""
196+
Returns a list of supported devices.
197+
"""
118198
return dpctl.get_devices()
199+
200+
201+
def __array_namespace_info__():
202+
"""__array_namespace_info__()
203+
204+
Returns a namespace with Array API namespace inspection utilities.
205+
206+
"""
207+
return Info()

0 commit comments

Comments
 (0)