Skip to content

Commit c63c545

Browse files
Expand test to cover non-contig. input that can be simplified into one
1 parent b3e9465 commit c63c545

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

dpctl/tests/test_tensor_sum.py

+9
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
4343
q = get_queue_or_skip()
4444
skip_if_dtype_not_supported(arg_dtype, q)
4545

46+
# test reduction for C-contiguous input
4647
m = dpt.ones(100, dtype=arg_dtype)
4748
r = dpt.sum(m)
4849

@@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
5556
assert r.dtype.kind == "f"
5657
elif m.dtype.kind == "c":
5758
assert r.dtype.kind == "c"
59+
5860
assert dpt.all(r == 100)
5961

62+
# test reduction for strided input
6063
m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
6164
r = dpt.sum(m)
6265
assert dpt.all(r == 99)
6366

67+
# test reduction for strided input which can be simplified
68+
# to contiguous computation
69+
m = dpt.ones(100, dtype=arg_dtype)
70+
r = dpt.sum(dpt.flip(m))
71+
assert dpt.all(r == 100)
72+
6473

6574
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
6675
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])

0 commit comments

Comments
 (0)