Skip to content

Commit ab63295

Browse files
Add more tests to improve test coverage
1 parent e6c6ba1 commit ab63295

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

dpctl/tests/elementwise/test_type_utils.py

+39
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import numpy as np
1718
import pytest
1819

1920
import dpctl
@@ -236,3 +237,41 @@ def test_can_cast_device():
236237
# can't safely cast inexact type to inexact type of lesser precision
237238
assert not tu._can_cast(dpt.float32, dpt.float16, True, False)
238239
assert not tu._can_cast(dpt.float64, dpt.float32, False, True)
240+
241+
242+
def test_acceptance_fns():
243+
"""Check type promotion acceptance functions"""
244+
assert tu._acceptance_fn_reciprocal(dpt.float32, dpt.float32)
245+
246+
247+
def test_weak_types():
248+
wbt = tu.WeakBooleanType(True)
249+
assert wbt.get()
250+
assert tu._weak_type_num_kind(wbt) == 0
251+
252+
wit = tu.WeakIntegralType(7)
253+
assert wit.get() == 7
254+
assert tu._weak_type_num_kind(wit) == 1
255+
256+
wft = tu.WeakFloatingType(3.1415926)
257+
assert wft.get() == 3.1415926
258+
assert tu._weak_type_num_kind(wft) == 2
259+
260+
wct = tu.WeakComplexType(2.0 + 3.0j)
261+
assert wct.get() == 2 + 3j
262+
assert tu._weak_type_num_kind(wct) == 3
263+
264+
265+
def test_arg_validation():
266+
with pytest.raises(TypeError):
267+
tu._weak_type_num_kind(dict())
268+
269+
with pytest.raises(TypeError):
270+
tu._strong_dtype_num_kind(Ellipsis)
271+
272+
with pytest.raises(ValueError):
273+
tu._strong_dtype_num_kind(np.dtype("O"))
274+
275+
wt = tu.WeakFloatingType(2.0)
276+
with pytest.raises(ValueError):
277+
tu._resolve_weak_types(wt, wt, None)

dpctl/tests/test_usm_ndarray_manipulation.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -935,13 +935,21 @@ def test_can_cast():
935935
def test_result_type():
936936
q = get_queue_or_skip()
937937

938-
X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64"]
939-
X_np = [np.ones((2), dtype=np.int16), np.int32, "int64"]
938+
usm_ar = dpt.ones((2), dtype=dpt.int16, sycl_queue=q)
939+
np_ar = dpt.asnumpy(usm_ar)
940+
941+
X = [usm_ar, dpt.int32, "int64", usm_ar]
942+
X_np = [np_ar, np.int32, "int64", np_ar]
943+
944+
assert dpt.result_type(*X) == np.result_type(*X_np)
945+
946+
X = [usm_ar, dpt.int32, "int64", True]
947+
X_np = [np_ar, np.int32, "int64", True]
940948

941949
assert dpt.result_type(*X) == np.result_type(*X_np)
942950

943-
X = [dpt.ones((2), dtype=dpt.int16, sycl_queue=q), dpt.int32, "int64", 2]
944-
X_np = [np.ones((2), dtype=np.int16), np.int32, "int64", 2]
951+
X = [usm_ar, dpt.int32, "int64", 2]
952+
X_np = [np_ar, np.int32, "int64", 2]
945953

946954
assert dpt.result_type(*X) == np.result_type(*X_np)
947955

@@ -950,6 +958,16 @@ def test_result_type():
950958

951959
assert dpt.result_type(*X) == np.result_type(*X_np)
952960

961+
X = [usm_ar, dpt.int32, "int64", 2.0]
962+
X_np = [np_ar, np.int32, "int64", 2.0]
963+
964+
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
965+
966+
X = [usm_ar, dpt.int32, "int64", 2.0 + 1j]
967+
X_np = [np_ar, np.int32, "int64", 2.0 + 1j]
968+
969+
assert dpt.result_type(*X).kind == np.result_type(*X_np).kind
970+
953971

954972
def test_swapaxes_1d():
955973
get_queue_or_skip()

0 commit comments

Comments
 (0)