Skip to content

Commit 85d4a59

Browse files
committed
Use _isdtype_impl in __array_namespace_info__ and isdtype
This change reduces redundancy
1 parent 919955f commit 85d4a59

File tree

2 files changed

+21
-43
lines changed

2 files changed

+21
-43
lines changed

dpctl/tensor/_array_api.py

+3-29
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,9 @@
2323
default_device_int_type,
2424
)
2525

26-
__array_api_version__ = "2022.12"
27-
28-
29-
def _isdtype_impl(dtype, kind):
30-
if isinstance(kind, dpt.dtype):
31-
return dtype == kind
26+
from ._data_types import _isdtype_impl
3227

33-
elif isinstance(kind, str):
34-
if kind == "bool":
35-
return dtype.kind == "b"
36-
elif kind == "signed integer":
37-
return dtype.kind == "i"
38-
elif kind == "unsigned integer":
39-
return dtype.kind == "u"
40-
elif kind == "integral":
41-
return dtype.kind in "iu"
42-
elif kind == "real floating":
43-
return dtype.kind == "f"
44-
elif kind == "complex floating":
45-
return dtype.kind == "c"
46-
elif kind == "numeric":
47-
return dtype.kind in "iufc"
48-
else:
49-
raise ValueError(f"Unrecognized data type kind: {kind}")
50-
51-
elif isinstance(kind, tuple):
52-
return any(_isdtype_impl(dtype, k) for k in kind)
53-
else:
54-
raise TypeError(f"Unsupported data type kind: {kind}")
28+
__array_api_version__ = "2022.12"
5529

5630

5731
class __array_namespace_info__:
@@ -88,7 +62,7 @@ def default_dtypes(self, device=None):
8862
device = dpctl.select_default_device()
8963
return {
9064
"real floating": default_device_fp_type(device),
91-
"complex floating": default_device_complex_type,
65+
"complex floating": default_device_complex_type(device),
9266
"integral": default_device_int_type(device),
9367
"indexing": default_device_index_type(device),
9468
}

dpctl/tensor/_data_types.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,7 @@
5050
complex128 = dtype("complex128")
5151

5252

53-
def isdtype(dtype_, kind):
54-
"""isdtype(dtype, kind)
55-
56-
Returns a boolean indicating whether a provided `dtype` is
57-
of a specified data type `kind`.
58-
59-
See [array API](array_api) for more information.
60-
61-
[array_api]: https://data-apis.org/array-api/latest/
62-
"""
63-
64-
if not isinstance(dtype_, dtype):
65-
raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype_}")
66-
53+
def _isdtype_impl(dtype_, kind):
6754
if isinstance(kind, dtype):
6855
return dtype_ == kind
6956

@@ -92,6 +79,23 @@ def isdtype(dtype_, kind):
9279
raise TypeError(f"Unsupported data type kind: {kind}")
9380

9481

82+
def isdtype(dtype_, kind):
83+
"""isdtype(dtype, kind)
84+
85+
Returns a boolean indicating whether a provided `dtype` is
86+
of a specified data type `kind`.
87+
88+
See [array API](array_api) for more information.
89+
90+
[array_api]: https://data-apis.org/array-api/latest/
91+
"""
92+
93+
if not isinstance(dtype_, dtype):
94+
raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype_}")
95+
96+
return _isdtype_impl(dtype_, kind)
97+
98+
9599
def _get_dtype(inp_dt, sycl_obj, ref_type=None):
96100
"""
97101
Type inference utility to construct data type

0 commit comments

Comments
 (0)