File tree 1 file changed +9
-0
lines changed
1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
43
43
q = get_queue_or_skip ()
44
44
skip_if_dtype_not_supported (arg_dtype , q )
45
45
46
+ # test reduction for C-contiguous input
46
47
m = dpt .ones (100 , dtype = arg_dtype )
47
48
r = dpt .sum (m )
48
49
@@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
55
56
assert r .dtype .kind == "f"
56
57
elif m .dtype .kind == "c" :
57
58
assert r .dtype .kind == "c"
59
+
58
60
assert dpt .all (r == 100 )
59
61
62
+ # test reduction for strided input
60
63
m = dpt .ones (200 , dtype = arg_dtype )[:1 :- 2 ]
61
64
r = dpt .sum (m )
62
65
assert dpt .all (r == 99 )
63
66
67
+ # test reduction for strided input which can be simplified
68
+ # to contiguous computation
69
+ m = dpt .ones (100 , dtype = arg_dtype )
70
+ r = dpt .sum (dpt .flip (m ))
71
+ assert dpt .all (r == 100 )
72
+
64
73
65
74
@pytest .mark .parametrize ("arg_dtype" , _all_dtypes )
66
75
@pytest .mark .parametrize ("out_dtype" , _all_dtypes [1 :])
You can’t perform that action at this time.
0 commit comments