Skip to content

Commit 412692a

Browse files
committed
Adds tests for array API inspection
1 parent 68875dc commit 412692a

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl
20+
import dpctl.tensor as dpt
21+
from dpctl.tensor._tensor_impl import (
22+
default_device_complex_type,
23+
default_device_fp_type,
24+
default_device_index_type,
25+
default_device_int_type,
26+
)
27+
28+
_dtypes_no_fp16_fp64 = {
29+
"bool": dpt.bool,
30+
"float32": dpt.float32,
31+
"complex64": dpt.complex64,
32+
"complex128": dpt.complex128,
33+
"int8": dpt.int8,
34+
"int16": dpt.int16,
35+
"int32": dpt.int32,
36+
"int64": dpt.int64,
37+
"uint8": dpt.uint8,
38+
"uint16": dpt.uint16,
39+
"uint32": dpt.uint32,
40+
"uint64": dpt.uint64,
41+
}
42+
43+
44+
class MockDevice:
45+
def __init__(self, fp16: bool, fp64: bool):
46+
self.has_aspect_fp16 = fp16
47+
self.has_aspect_fp64 = fp64
48+
49+
50+
def test_array_api_inspection_methods():
51+
info = dpt.__array_namespace_info__()
52+
assert info.capabilities()
53+
assert info.default_device()
54+
assert info.default_dtypes()
55+
assert info.devices()
56+
assert info.dtypes()
57+
58+
59+
def test_array_api_inspection_default_device():
60+
assert (
61+
dpt.__array_namespace_info__().default_device()
62+
== dpctl.select_default_device()
63+
)
64+
65+
66+
def test_array_api_inspection_devices():
67+
devices1 = dpt.__array_namespace_info__().devices()
68+
devices2 = dpctl.get_devices()
69+
assert len(devices1) == len(devices2)
70+
assert devices1 == devices2
71+
72+
73+
def test_array_api_inspection_capabilities():
74+
capabilities = dpt.__array_namespace_info__().capabilities()
75+
assert capabilities["boolean_indexing"]
76+
assert capabilities["data_dependent_shapes"]
77+
78+
79+
def test_array_api_inspection_default_dtypes():
80+
dev = dpctl.select_default_device()
81+
82+
int_dt = default_device_int_type(dev)
83+
ind_dt = default_device_index_type(dev)
84+
fp_dt = default_device_fp_type(dev)
85+
cm_dt = default_device_complex_type(dev)
86+
87+
info = dpt.__array_namespace_info__()
88+
default_dts_nodev = info.default_dtypes()
89+
default_dts_dev = info.default_dtypes(dev)
90+
91+
assert (
92+
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
93+
)
94+
assert (
95+
ind_dt == default_dts_nodev["indexing"] == default_dts_dev["indexing"]
96+
)
97+
assert (
98+
fp_dt
99+
== default_dts_nodev["real floating"]
100+
== default_dts_dev["real floating"]
101+
)
102+
assert (
103+
cm_dt
104+
== default_dts_nodev["complex floating"]
105+
== default_dts_dev["complex floating"]
106+
)
107+
108+
109+
def test_array_api_inspection_default_device_dtypes():
110+
dev = dpctl.select_default_device()
111+
dtypes = _dtypes_no_fp16_fp64.copy()
112+
if dev.has_aspect_fp64:
113+
dtypes["float64"] = dpt.float64
114+
115+
assert dtypes == dpt.__array_namespace_info__().dtypes()
116+
117+
118+
@pytest.mark.parametrize("fp16", [True, False])
119+
@pytest.mark.parametrize("fp64", [True, False])
120+
def test_array_api_inspection_device_dtypes(fp16, fp64):
121+
dev = MockDevice(fp16, fp64)
122+
dtypes = _dtypes_no_fp16_fp64.copy()
123+
if fp64:
124+
dtypes["float64"] = dpt.float64
125+
126+
assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)
127+
128+
129+
def test_array_api_inspection_dtype_kind():
130+
info = dpt.__array_namespace_info__()
131+
132+
f_dtypes = info.dtypes(kind="real floating")
133+
assert all([_dt[1].kind == "f" for _dt in f_dtypes.items()])
134+
135+
i_dtypes = info.dtypes(kind="signed integer")
136+
assert all([_dt[1].kind == "i" for _dt in i_dtypes.items()])
137+
138+
u_dtypes = info.dtypes(kind="unsigned integer")
139+
assert all([_dt[1].kind == "u" for _dt in u_dtypes.items()])
140+
141+
ui_dtypes = info.dtypes(kind="unsigned integer")
142+
assert all([_dt[1].kind in "ui" for _dt in ui_dtypes.items()])
143+
144+
c_dtypes = info.dtypes(kind="complex floating")
145+
assert all([_dt[1].kind == "c" for _dt in c_dtypes.items()])
146+
147+
assert info.dtypes(kind="bool") == {"bool": dpt.bool}
148+
149+
_signed_ints = {
150+
"int8": dpt.int8,
151+
"int16": dpt.int16,
152+
"int32": dpt.int32,
153+
"int64": dpt.int64,
154+
}
155+
assert (
156+
info.dtypes(kind=("signed integer", "signed integer")) == _signed_ints
157+
)
158+
assert (
159+
info.dtypes(
160+
kind=("integral", "bool", "real floating", "complex floating")
161+
)
162+
== info.dtypes()
163+
)

0 commit comments

Comments
 (0)