Skip to content

Commit 9bbdbdf

Browse files
committed
update tests to reduce duplication
1 parent 28f4fe9 commit 9bbdbdf

File tree

1 file changed

+19
-49
lines changed

1 file changed

+19
-49
lines changed

tests/test_statistics.py

+19-49
Original file line numberDiff line numberDiff line change
@@ -22,97 +22,67 @@ def test_median(dtype, size):
2222
assert_allclose(dpnp_res, np_res)
2323

2424

25+
@pytest.mark.parametrize("func", ["max", "min"])
2526
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
2627
@pytest.mark.parametrize("keepdims", [False, True])
2728
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
28-
def test_max_min(axis, keepdims, dtype):
29+
def test_max_min(func, axis, keepdims, dtype):
2930
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
3031
ia = dpnp.array(a)
3132

32-
np_res = numpy.max(a, axis=axis, keepdims=keepdims)
33-
dpnp_res = dpnp.max(ia, axis=axis, keepdims=keepdims)
34-
35-
assert dpnp_res.shape == np_res.shape
36-
assert_allclose(dpnp_res, np_res)
37-
38-
np_res = numpy.min(a, axis=axis, keepdims=keepdims)
39-
dpnp_res = dpnp.min(ia, axis=axis, keepdims=keepdims)
33+
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
34+
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
4035

4136
assert dpnp_res.shape == np_res.shape
4237
assert_allclose(dpnp_res, np_res)
4338

4439

40+
@pytest.mark.parametrize("func", ["max", "min"])
4541
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
4642
@pytest.mark.parametrize("keepdims", [False, True])
47-
def test_max_min_bool(axis, keepdims):
43+
def test_max_min_bool(func, axis, keepdims):
4844
a = numpy.arange(2, dtype=dpnp.bool)
4945
a = numpy.tile(a, (2, 2))
5046
ia = dpnp.array(a)
5147

52-
np_res = numpy.max(a, axis=axis, keepdims=keepdims)
53-
dpnp_res = dpnp.max(ia, axis=axis, keepdims=keepdims)
54-
55-
assert dpnp_res.shape == np_res.shape
56-
assert_allclose(dpnp_res, np_res)
57-
58-
np_res = numpy.min(a, axis=axis, keepdims=keepdims)
59-
dpnp_res = dpnp.min(ia, axis=axis, keepdims=keepdims)
48+
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
49+
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
6050

6151
assert dpnp_res.shape == np_res.shape
6252
assert_allclose(dpnp_res, np_res)
6353

6454

65-
def test_max_min_out():
55+
@pytest.mark.parametrize("func", ["max", "min"])
56+
def test_max_min_out(func):
6657
a = numpy.arange(6).reshape((2, 3))
6758
ia = dpnp.array(a)
6859

69-
np_res = numpy.max(a, axis=0)
60+
np_res = getattr(numpy, func)(a, axis=0)
7061
dpnp_res = dpnp.array(numpy.empty_like(np_res))
71-
dpnp.max(ia, axis=0, out=dpnp_res)
62+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
7263
assert_allclose(dpnp_res, np_res)
7364

7465
dpnp_res = dpt.asarray(numpy.empty_like(np_res))
75-
dpnp.max(ia, axis=0, out=dpnp_res)
66+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
7667
assert_allclose(dpnp_res, np_res)
7768

7869
dpnp_res = numpy.empty_like(np_res)
7970
with pytest.raises(TypeError):
80-
dpnp.max(ia, axis=0, out=dpnp_res)
71+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
8172

8273
dpnp_res = dpnp.array(numpy.empty((2, 3)))
8374
with pytest.raises(ValueError):
84-
dpnp.max(ia, axis=0, out=dpnp_res)
75+
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
8576

86-
np_res = numpy.min(a, axis=0)
87-
dpnp_res = dpnp.array(numpy.empty_like(np_res))
88-
dpnp.min(ia, axis=0, out=dpnp_res)
89-
assert_allclose(dpnp_res, np_res)
9077

91-
dpnp_res = dpt.asarray(numpy.empty_like(np_res))
92-
dpnp.min(ia, axis=0, out=dpnp_res)
93-
assert_allclose(dpnp_res, np_res)
94-
95-
dpnp_res = numpy.empty_like(np_res)
96-
with pytest.raises(TypeError):
97-
dpnp.min(ia, axis=0, out=dpnp_res)
98-
99-
dpnp_res = dpnp.array(numpy.empty((2, 3)))
100-
with pytest.raises(ValueError):
101-
dpnp.min(ia, axis=0, out=dpnp_res)
102-
103-
104-
def test_max_min_NotImplemented():
78+
@pytest.mark.parametrize("func", ["max", "min"])
79+
def test_max_min_NotImplemented(func):
10580
ia = dpnp.arange(5)
10681

10782
with pytest.raises(NotImplementedError):
108-
dpnp.max(ia, where=False)
109-
with pytest.raises(NotImplementedError):
110-
dpnp.max(ia, initial=6)
111-
112-
with pytest.raises(NotImplementedError):
113-
dpnp.min(ia, where=False)
83+
getattr(dpnp, func)(ia, where=False)
11484
with pytest.raises(NotImplementedError):
115-
dpnp.min(ia, initial=6)
85+
getattr(dpnp, func)(ia, initial=6)
11686

11787

11888
@pytest.mark.usefixtures("allow_fall_back_on_numpy")

0 commit comments

Comments
 (0)