Skip to content

Commit 89b4762

Browse files
committed
Add more test for 'out' parameter
1 parent 11abb7c commit 89b4762

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

tests/test_mathematical.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,14 +656,46 @@ def test_add(self, dtype):
656656
assert_allclose(expected, result)
657657
assert_allclose(out, dp_out)
658658

659-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True, no_none=True))
660-
def test_invalid_dtype(self, dtype):
661-
dp_array1 = dpnp.arange(10, dtype=dpnp.complex64)
662-
dp_array2 = dpnp.arange(5, 15, dtype=dpnp.complex64)
663-
dp_out = dpnp.empty(10, dtype=dtype)
659+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
660+
def test_out_dtypes(self, dtype):
661+
size = 2 if dtype == dpnp.bool else 10
664662

665-
with pytest.raises(ValueError):
666-
dpnp.add(dp_array1, dp_array2, out=dp_out)
663+
np_array1 = numpy.arange(size, 2 * size, dtype=dtype)
664+
np_array2 = numpy.arange(size, dtype=dtype)
665+
np_out = numpy.empty(size, dtype=numpy.complex64)
666+
expected = numpy.add(np_array1, np_array2, out=np_out)
667+
668+
dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
669+
dp_array2 = dpnp.arange(size, dtype=dtype)
670+
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
671+
result = dpnp.add(dp_array1, dp_array2, out=dp_out)
672+
673+
assert_array_equal(expected, result)
674+
675+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
676+
def test_out_overlap(self, dtype):
677+
size = 1 if dtype == dpnp.bool else 15
678+
679+
np_a = numpy.arange(2 * size, dtype=dtype)
680+
expected = numpy.add(np_a[size::], np_a[::2], out=np_a[:size:])
681+
682+
dp_a = dpnp.arange(2 * size, dtype=dtype)
683+
result = dpnp.add(dp_a[size::], dp_a[::2], out=dp_a[:size:])
684+
685+
assert_allclose(expected, result)
686+
assert_allclose(dp_a, np_a)
687+
688+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True))
689+
def test_inplace_strided_out(self, dtype):
690+
size = 21
691+
692+
np_a = numpy.arange(size, dtype=dtype)
693+
np_a[::3] += 4
694+
695+
dp_a = dpnp.arange(size, dtype=dtype)
696+
dp_a[::3] += 4
697+
698+
assert_allclose(dp_a, np_a)
667699

668700
@pytest.mark.parametrize("shape",
669701
[(0,), (15, ), (2, 2)],
@@ -715,7 +747,7 @@ def test_out_dtypes(self, dtype):
715747
np_out = numpy.empty(size, dtype=numpy.complex64)
716748
expected = numpy.power(np_array1, np_array2, out=np_out)
717749

718-
dp_array1 = dpnp.arange(size, 2*size, dtype=dtype)
750+
dp_array1 = dpnp.arange(size, 2 * size, dtype=dtype)
719751
dp_array2 = dpnp.arange(size, dtype=dtype)
720752
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
721753
result = dpnp.power(dp_array1, dp_array2, out=dp_out)

tests/test_strides.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_strides_true_devide(dtype, shape):
217217

218218

219219
@pytest.mark.parametrize("func_name",
220-
["power"])
220+
["add", "power"])
221221
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
222222
def test_strided_out_2args(func_name, dtype):
223223
np_out = numpy.ones((5, 3, 2))[::3]
@@ -236,7 +236,7 @@ def test_strided_out_2args(func_name, dtype):
236236

237237

238238
@pytest.mark.parametrize("func_name",
239-
["power"])
239+
["add", "power"])
240240
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
241241
def test_strided_in_out_2args(func_name, dtype):
242242
sh = (3, 4, 2)
@@ -258,7 +258,7 @@ def test_strided_in_out_2args(func_name, dtype):
258258

259259

260260
@pytest.mark.parametrize("func_name",
261-
["power"])
261+
["add", "power"])
262262
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
263263
def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
264264
sh = (3, 3, 2)
@@ -280,7 +280,7 @@ def test_strided_in_out_2args_diff_out_dtype(func_name, dtype):
280280

281281

282282
@pytest.mark.parametrize("func_name",
283-
["power"])
283+
["add", "power"])
284284
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
285285
def test_strided_in_2args_overlap(func_name, dtype):
286286
size = 5
@@ -296,7 +296,7 @@ def test_strided_in_2args_overlap(func_name, dtype):
296296

297297

298298
@pytest.mark.parametrize("func_name",
299-
["power"])
299+
["add", "power"])
300300
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True))
301301
def test_strided_in_out_2args_overlap(func_name, dtype):
302302
sh = (4, 3, 2)

0 commit comments

Comments
 (0)