Skip to content

Commit 01e9d9c

Browse files
Add more tests to improve test coverage
1 parent e6c6ba1 commit 01e9d9c

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

dpctl/tests/elementwise/test_type_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import numpy as np
1718
import pytest
1819

1920
import dpctl
@@ -236,3 +237,44 @@ def test_can_cast_device():
236237
# can't safely cast inexact type to inexact type of lesser precision
237238
assert not tu._can_cast(dpt.float32, dpt.float16, True, False)
238239
assert not tu._can_cast(dpt.float64, dpt.float32, False, True)
240+
241+
242+
def test_acceptance_fns():
243+
"""Check type promotion acceptance functions"""
244+
dev = dpctl.SyclDevice()
245+
assert tu._acceptance_fn_reciprocal(
246+
dpt.float32, dpt.float32, dpt.float32, dev
247+
)
248+
249+
250+
def test_weak_types():
251+
wbt = tu.WeakBooleanType(True)
252+
assert wbt.get()
253+
assert tu._weak_type_num_kind(wbt) == 0
254+
255+
wit = tu.WeakIntegralType(7)
256+
assert wit.get() == 7
257+
assert tu._weak_type_num_kind(wit) == 1
258+
259+
wft = tu.WeakFloatingType(3.1415926)
260+
assert wft.get() == 3.1415926
261+
assert tu._weak_type_num_kind(wft) == 2
262+
263+
wct = tu.WeakComplexType(2.0 + 3.0j)
264+
assert wct.get() == 2 + 3j
265+
assert tu._weak_type_num_kind(wct) == 3
266+
267+
268+
def test_arg_validation():
269+
with pytest.raises(TypeError):
270+
tu._weak_type_num_kind(dict())
271+
272+
with pytest.raises(TypeError):
273+
tu._strong_dtype_num_kind(Ellipsis)
274+
275+
with pytest.raises(ValueError):
276+
tu._strong_dtype_num_kind(np.dtype("O"))
277+
278+
wt = tu.WeakFloatingType(2.0)
279+
with pytest.raises(ValueError):
280+
tu._resolve_weak_types(wt, wt, None)

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -935,13 +935,21 @@ def test_can_cast():
935935
def test_result_type():
936936
q = get_queue_or_skip()
937937

938-
X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64"]
939-
X_np = [np.ones((2), dtype=np.int16), np.int32, "int64"]
938+
usm_ar = dpt.ones((2), dtype=dpt.int16, sycl_queue=q)
939+
np_ar = dpt.asnumpy(usm_ar)
940+
941+
X = [usm_ar, dpt.int32, "int64", usm_ar]
942+
X_np = [np_ar, np.int32, "int64", np_ar]
943+
944+
assert dpt.result_type(*X) == np.result_type(*X_np)
945+
946+
X = [usm_ar, dpt.int32, "int64", True]
947+
X_np = [np_ar, np.int32, "int64", True]
940948

941949
assert dpt.result_type(*X) == np.result_type(*X_np)
942950

943-
X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64", 2]
944-
X_np = [np.ones((2), dtype=np.int16), np.int32, "int64", 2]
951+
X = [usm_ar, dpt.int32, "int64", 2]
952+
X_np = [np_ar, np.int32, "int64", 2]
945953

946954
assert dpt.result_type(*X) == np.result_type(*X_np)
947955

@@ -950,6 +958,16 @@ def test_result_type():
950958

951959
assert dpt.result_type(*X) == np.result_type(*X_np)
952960

961+
X = [usm_ar, dpt.int32, "int64", 2.0]
962+
X_np = [np_ar, np.int32, "int64", 2.0]
963+
964+
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
965+
966+
X = [usm_ar, dpt.int32, "int64", 2.0 + 1j]
967+
X_np = [np_ar, np.int32, "int64", 2.0 + 1j]
968+
969+
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
970+
953971

954972
def test_swapaxes_1d():
955973
get_queue_or_skip()

0 commit comments

Comments
 (0)