Skip to content

Commit 13a4262

Browse files
committed
Adds tests for __array_namespace_info__
1 parent 85d4a59 commit 13a4262

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl
20+
import dpctl.tensor as dpt
21+
from dpctl.tensor._tensor_impl import (
22+
default_device_complex_type,
23+
default_device_fp_type,
24+
default_device_index_type,
25+
default_device_int_type,
26+
)
27+
28+
_dtypes_no_fp16_fp64 = {
29+
"bool": dpt.bool,
30+
"float32": dpt.float32,
31+
"complex64": dpt.complex64,
32+
"complex128": dpt.complex128,
33+
"int8": dpt.int8,
34+
"int16": dpt.int16,
35+
"int32": dpt.int32,
36+
"int64": dpt.int64,
37+
"uint8": dpt.uint8,
38+
"uint16": dpt.uint16,
39+
"uint32": dpt.uint32,
40+
"uint64": dpt.uint64,
41+
}
42+
43+
44+
class MockDevice:
45+
def __init__(self, fp16: bool, fp64: bool):
46+
self.has_aspect_fp16 = fp16
47+
self.has_aspect_fp64 = fp64
48+
49+
50+
def test_array_api_inspection_methods():
51+
info = dpt.__array_namespace_info__()
52+
assert info.capabilities()
53+
assert info.default_device()
54+
assert info.default_dtypes()
55+
assert info.devices()
56+
assert info.dtypes()
57+
58+
59+
def test_array_api_inspection_default_device():
60+
assert (
61+
dpt.__array_namespace_info__().default_device()
62+
== dpctl.select_default_device()
63+
)
64+
65+
66+
def test_array_api_inspection_devices():
67+
devices1 = dpt.__array_namespace_info__().devices()
68+
devices2 = dpctl.get_devices()
69+
assert len(devices1) == len(devices2)
70+
assert devices1 == devices2
71+
72+
73+
def test_array_api_inspection_capabilities():
74+
capabilities = dpt.__array_namespace_info__().capabilities()
75+
assert capabilities["boolean_indexing"]
76+
assert capabilities["data_dependent_shapes"]
77+
78+
79+
def test_array_api_inspection_default_dtypes():
80+
dev = dpctl.select_default_device()
81+
82+
int_dt = default_device_int_type(dev)
83+
ind_dt = default_device_index_type(dev)
84+
fp_dt = default_device_fp_type(dev)
85+
cm_dt = default_device_complex_type(dev)
86+
87+
info = dpt.__array_namespace_info__()
88+
default_dts_nodev = info.default_dtypes()
89+
default_dts_dev = info.default_dtypes(dev)
90+
91+
assert (
92+
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
93+
)
94+
assert (
95+
ind_dt == default_dts_nodev["indexing"] == default_dts_dev["indexing"]
96+
)
97+
assert (
98+
fp_dt
99+
== default_dts_nodev["real floating"]
100+
== default_dts_dev["real floating"]
101+
)
102+
assert (
103+
cm_dt
104+
== default_dts_nodev["complex floating"]
105+
== default_dts_dev["complex floating"]
106+
)
107+
108+
109+
def test_array_api_inspection_default_device_dtypes():
110+
dev = dpctl.select_default_device()
111+
dtypes = _dtypes_no_fp16_fp64.copy()
112+
if dev.has_aspect_fp16:
113+
dtypes["float16"] = dpt.float16
114+
if dev.has_aspect_fp64:
115+
dtypes["float64"] = dpt.float64
116+
117+
assert dtypes == dpt.__array_namespace_info__().dtypes()
118+
119+
120+
@pytest.mark.parametrize("fp16", [True, False])
121+
@pytest.mark.parametrize("fp64", [True, False])
122+
def test_array_api_inspection_device_dtypes(fp16, fp64):
123+
dev = MockDevice(fp16, fp64)
124+
dtypes = _dtypes_no_fp16_fp64.copy()
125+
if fp16:
126+
dtypes["float16"] = dpt.float16
127+
if fp64:
128+
dtypes["float64"] = dpt.float64
129+
130+
assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)
131+
132+
133+
def test_array_api_inspection_dtype_kind():
134+
info = dpt.__array_namespace_info__()
135+
136+
f_dtypes = info.dtypes(kind="real floating")
137+
assert all([_dt[1].kind == "f" for _dt in f_dtypes.items()])
138+
139+
i_dtypes = info.dtypes(kind="signed integer")
140+
assert all([_dt[1].kind == "i" for _dt in i_dtypes.items()])
141+
142+
u_dtypes = info.dtypes(kind="unsigned integer")
143+
assert all([_dt[1].kind == "u" for _dt in u_dtypes.items()])
144+
145+
ui_dtypes = info.dtypes(kind="unsigned integer")
146+
assert all([_dt[1].kind in "ui" for _dt in ui_dtypes.items()])
147+
148+
c_dtypes = info.dtypes(kind="complex floating")
149+
assert all([_dt[1].kind == "c" for _dt in c_dtypes.items()])
150+
151+
assert info.dtypes(kind="bool") == {"bool": dpt.bool}
152+
153+
_fp32 = {"float32": dpt.float32}
154+
_signed_ints = {
155+
"int8": dpt.int8,
156+
"int16": dpt.int16,
157+
"int32": dpt.int32,
158+
"int64": dpt.int64,
159+
}
160+
assert info.dtypes(kind=dpt.float32) == _fp32
161+
assert info.dtypes(kind=(dpt.float32, dpt.float32)) == _fp32
162+
assert (
163+
info.dtypes(kind=("signed integer", "signed integer")) == _signed_ints
164+
)
165+
assert info.dtypes(kind=(dpt.float32, "bool")) == {
166+
"float32": dpt.float32,
167+
"bool": dpt.bool,
168+
}
169+
assert (
170+
info.dtypes(
171+
kind=("integral", "bool", "real floating", "complex floating")
172+
)
173+
== info.dtypes()
174+
)

0 commit comments

Comments
 (0)