Skip to content

Commit f2af753

Browse files
authored
Add array API inspection utilities to dpctl.tensor (#1469)
* Adds __array_namespace_info__ inspection utility This inspection utility is coming to the array API specification in the near future * Set __array_api_version__ to "2022.12" * Remove --ci from array API conformity workflow * Adds __array_namespace_info__ docstrings Disallows dtypes for `kind` kwarg in __array_namespace_info__().dtypes Removes `float16` from dtypes listed by __array_namespace_info__ as per spec Permits dpctl.tensor.Device objects in device keyword arguments in array API inspection utilities * Adds tests for array API inspection
1 parent f686102 commit f2af753

File tree

4 files changed

+374
-1
lines changed

4 files changed

+374
-1
lines changed

.github/workflows/conda-package.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ jobs:
666666
python -c "import dpctl; dpctl.lsplatform()"
667667
export ARRAY_API_TESTS_MODULE=dpctl.tensor
668668
cd /home/runner/work/array-api-tests
669-
pytest --ci --json-report --json-report-file=$FILE array_api_tests/ || true
669+
pytest --json-report --json-report-file=$FILE array_api_tests/ || true
670670
- name: Set Github environment variables
671671
shell: bash -l {0}
672672
run: |

dpctl/tensor/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
from dpctl.tensor._usmarray import usm_ndarray
9494
from dpctl.tensor._utility_functions import all, any
9595

96+
from ._array_api import __array_api_version__, __array_namespace_info__
9697
from ._clip import clip
9798
from ._constants import e, inf, nan, newaxis, pi
9899
from ._elementwise_funcs import (
@@ -335,4 +336,6 @@
335336
"clip",
336337
"logsumexp",
337338
"reduce_hypot",
339+
"__array_api_version__",
340+
"__array_namespace_info__",
338341
]

dpctl/tensor/_array_api.py

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl
18+
import dpctl.tensor as dpt
19+
from dpctl.tensor._tensor_impl import (
20+
default_device_complex_type,
21+
default_device_fp_type,
22+
default_device_index_type,
23+
default_device_int_type,
24+
)
25+
26+
27+
def _isdtype_impl(dtype, kind):
28+
if isinstance(kind, str):
29+
if kind == "bool":
30+
return dtype.kind == "b"
31+
elif kind == "signed integer":
32+
return dtype.kind == "i"
33+
elif kind == "unsigned integer":
34+
return dtype.kind == "u"
35+
elif kind == "integral":
36+
return dtype.kind in "iu"
37+
elif kind == "real floating":
38+
return dtype.kind == "f"
39+
elif kind == "complex floating":
40+
return dtype.kind == "c"
41+
elif kind == "numeric":
42+
return dtype.kind in "iufc"
43+
else:
44+
raise ValueError(f"Unrecognized data type kind: {kind}")
45+
46+
elif isinstance(kind, tuple):
47+
return any(_isdtype_impl(dtype, k) for k in kind)
48+
else:
49+
raise TypeError(f"Unsupported data type kind: {kind}")
50+
51+
52+
__array_api_version__ = "2022.12"
53+
54+
55+
class Info:
56+
"""
57+
namespace returned by `__array_namespace_info__()`
58+
"""
59+
60+
def __init__(self):
61+
self._capabilities = {
62+
"boolean_indexing": True,
63+
"data_dependent_shapes": True,
64+
}
65+
self._all_dtypes = {
66+
"bool": dpt.bool,
67+
"float32": dpt.float32,
68+
"float64": dpt.float64,
69+
"complex64": dpt.complex64,
70+
"complex128": dpt.complex128,
71+
"int8": dpt.int8,
72+
"int16": dpt.int16,
73+
"int32": dpt.int32,
74+
"int64": dpt.int64,
75+
"uint8": dpt.uint8,
76+
"uint16": dpt.uint16,
77+
"uint32": dpt.uint32,
78+
"uint64": dpt.uint64,
79+
}
80+
81+
def capabilities(self):
82+
"""
83+
Returns a dictionary of `dpctl`'s capabilities.
84+
85+
Returns:
86+
dict:
87+
dictionary of `dpctl`'s capabilities
88+
- `boolean_indexing`: bool
89+
- `data_dependent_shapes`: bool
90+
"""
91+
return self._capabilities.copy()
92+
93+
def default_device(self):
94+
"""
95+
Returns the default SYCL device.
96+
"""
97+
return dpctl.select_default_device()
98+
99+
def default_dtypes(self, device=None):
100+
"""
101+
Returns a dictionary of default data types for `device`.
102+
103+
Args:
104+
device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
105+
dpctl.tensor.Device]):
106+
array API concept of device used in getting default data types.
107+
`device` can be `None` (in which case the default device is
108+
used), an instance of :class:`dpctl.SyclDevice` corresponding
109+
to a non-partitioned SYCL device, an instance of
110+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
111+
:attr:`dpctl.tensor.usm_array.device`. Default: `None`.
112+
113+
Returns:
114+
dict:
115+
a dictionary of default data types for `device`
116+
- `real floating`: dtype
117+
- `complex floating`: dtype
118+
- `integral`: dtype
119+
- `indexing`: dtype
120+
"""
121+
if device is None:
122+
device = dpctl.select_default_device()
123+
elif isinstance(device, dpt.Device):
124+
device = device.sycl_device
125+
return {
126+
"real floating": dpt.dtype(default_device_fp_type(device)),
127+
"complex floating": dpt.dtype(default_device_complex_type(device)),
128+
"integral": dpt.dtype(default_device_int_type(device)),
129+
"indexing": dpt.dtype(default_device_index_type(device)),
130+
}
131+
132+
def dtypes(self, device=None, kind=None):
133+
"""
134+
Returns a dictionary of all Array API data types of a specified `kind`
135+
supported by `device`
136+
137+
This dictionary only includes data types supported by the array API.
138+
139+
See [array API](array_api).
140+
141+
[array_api]: https://data-apis.org/array-api/latest/
142+
143+
Args:
144+
device (Optional[dpctl.SyclDevice, dpctl.SyclQueue,
145+
dpctl.tensor.Device, str]):
146+
array API concept of device used in getting default data types.
147+
`device` can be `None` (in which case the default device is
148+
used), an instance of :class:`dpctl.SyclDevice` corresponding
149+
to a non-partitioned SYCL device, an instance of
150+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
151+
:attr:`dpctl.tensor.usm_array.device`. Default: `None`.
152+
153+
kind (Optional[str, Tuple[str, ...]]):
154+
data type kind.
155+
- if `kind` is `None`, returns a dictionary of all data types
156+
supported by `device`
157+
- if `kind` is a string, returns a dictionary containing the
158+
data types belonging to the data type kind specified.
159+
Supports:
160+
- "bool"
161+
- "signed integer"
162+
- "unsigned integer"
163+
- "integral"
164+
- "real floating"
165+
- "complex floating"
166+
- "numeric"
167+
- if `kind` is a tuple, the tuple represents a union of `kind`
168+
strings, and returns a dictionary containing data types
169+
corresponding to the-specified union.
170+
Default: `None`.
171+
172+
Returns:
173+
dict:
174+
a dictionary of the supported data types of the specified `kind`
175+
"""
176+
if device is None:
177+
device = dpctl.select_default_device()
178+
elif isinstance(device, dpt.Device):
179+
device = device.sycl_device
180+
_fp64 = device.has_aspect_fp64
181+
if kind is None:
182+
return {
183+
key: val
184+
for key, val in self._all_dtypes.items()
185+
if (key != "float64" or _fp64)
186+
}
187+
else:
188+
return {
189+
key: val
190+
for key, val in self._all_dtypes.items()
191+
if (key != "float64" or _fp64) and _isdtype_impl(val, kind)
192+
}
193+
194+
def devices(self):
195+
"""
196+
Returns a list of supported devices.
197+
"""
198+
return dpctl.get_devices()
199+
200+
201+
def __array_namespace_info__():
202+
"""__array_namespace_info__()
203+
204+
Returns a namespace with Array API namespace inspection utilities.
205+
206+
"""
207+
return Info()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl
20+
import dpctl.tensor as dpt
21+
from dpctl.tensor._tensor_impl import (
22+
default_device_complex_type,
23+
default_device_fp_type,
24+
default_device_index_type,
25+
default_device_int_type,
26+
)
27+
28+
_dtypes_no_fp16_fp64 = {
29+
"bool": dpt.bool,
30+
"float32": dpt.float32,
31+
"complex64": dpt.complex64,
32+
"complex128": dpt.complex128,
33+
"int8": dpt.int8,
34+
"int16": dpt.int16,
35+
"int32": dpt.int32,
36+
"int64": dpt.int64,
37+
"uint8": dpt.uint8,
38+
"uint16": dpt.uint16,
39+
"uint32": dpt.uint32,
40+
"uint64": dpt.uint64,
41+
}
42+
43+
44+
class MockDevice:
45+
def __init__(self, fp16: bool, fp64: bool):
46+
self.has_aspect_fp16 = fp16
47+
self.has_aspect_fp64 = fp64
48+
49+
50+
def test_array_api_inspection_methods():
51+
info = dpt.__array_namespace_info__()
52+
assert info.capabilities()
53+
assert info.default_device()
54+
assert info.default_dtypes()
55+
assert info.devices()
56+
assert info.dtypes()
57+
58+
59+
def test_array_api_inspection_default_device():
60+
assert (
61+
dpt.__array_namespace_info__().default_device()
62+
== dpctl.select_default_device()
63+
)
64+
65+
66+
def test_array_api_inspection_devices():
67+
devices1 = dpt.__array_namespace_info__().devices()
68+
devices2 = dpctl.get_devices()
69+
assert len(devices1) == len(devices2)
70+
assert devices1 == devices2
71+
72+
73+
def test_array_api_inspection_capabilities():
74+
capabilities = dpt.__array_namespace_info__().capabilities()
75+
assert capabilities["boolean_indexing"]
76+
assert capabilities["data_dependent_shapes"]
77+
78+
79+
def test_array_api_inspection_default_dtypes():
80+
dev = dpctl.select_default_device()
81+
82+
int_dt = default_device_int_type(dev)
83+
ind_dt = default_device_index_type(dev)
84+
fp_dt = default_device_fp_type(dev)
85+
cm_dt = default_device_complex_type(dev)
86+
87+
info = dpt.__array_namespace_info__()
88+
default_dts_nodev = info.default_dtypes()
89+
default_dts_dev = info.default_dtypes(dev)
90+
91+
assert (
92+
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
93+
)
94+
assert (
95+
ind_dt == default_dts_nodev["indexing"] == default_dts_dev["indexing"]
96+
)
97+
assert (
98+
fp_dt
99+
== default_dts_nodev["real floating"]
100+
== default_dts_dev["real floating"]
101+
)
102+
assert (
103+
cm_dt
104+
== default_dts_nodev["complex floating"]
105+
== default_dts_dev["complex floating"]
106+
)
107+
108+
109+
def test_array_api_inspection_default_device_dtypes():
110+
dev = dpctl.select_default_device()
111+
dtypes = _dtypes_no_fp16_fp64.copy()
112+
if dev.has_aspect_fp64:
113+
dtypes["float64"] = dpt.float64
114+
115+
assert dtypes == dpt.__array_namespace_info__().dtypes()
116+
117+
118+
@pytest.mark.parametrize("fp16", [True, False])
119+
@pytest.mark.parametrize("fp64", [True, False])
120+
def test_array_api_inspection_device_dtypes(fp16, fp64):
121+
dev = MockDevice(fp16, fp64)
122+
dtypes = _dtypes_no_fp16_fp64.copy()
123+
if fp64:
124+
dtypes["float64"] = dpt.float64
125+
126+
assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)
127+
128+
129+
def test_array_api_inspection_dtype_kind():
130+
info = dpt.__array_namespace_info__()
131+
132+
f_dtypes = info.dtypes(kind="real floating")
133+
assert all([_dt[1].kind == "f" for _dt in f_dtypes.items()])
134+
135+
i_dtypes = info.dtypes(kind="signed integer")
136+
assert all([_dt[1].kind == "i" for _dt in i_dtypes.items()])
137+
138+
u_dtypes = info.dtypes(kind="unsigned integer")
139+
assert all([_dt[1].kind == "u" for _dt in u_dtypes.items()])
140+
141+
ui_dtypes = info.dtypes(kind="unsigned integer")
142+
assert all([_dt[1].kind in "ui" for _dt in ui_dtypes.items()])
143+
144+
c_dtypes = info.dtypes(kind="complex floating")
145+
assert all([_dt[1].kind == "c" for _dt in c_dtypes.items()])
146+
147+
assert info.dtypes(kind="bool") == {"bool": dpt.bool}
148+
149+
_signed_ints = {
150+
"int8": dpt.int8,
151+
"int16": dpt.int16,
152+
"int32": dpt.int32,
153+
"int64": dpt.int64,
154+
}
155+
assert (
156+
info.dtypes(kind=("signed integer", "signed integer")) == _signed_ints
157+
)
158+
assert (
159+
info.dtypes(
160+
kind=("integral", "bool", "real floating", "complex floating")
161+
)
162+
== info.dtypes()
163+
)

0 commit comments

Comments
 (0)