Skip to content

Commit 07e5f86

Browse files
committed
address reviewers' comments
1 parent 5bfeb2c commit 07e5f86

File tree

2 files changed

+53
-46
lines changed

2 files changed

+53
-46
lines changed

tests/test_mathematical.py

+49-42
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .helper import (
1616
get_all_dtypes,
1717
get_float_complex_dtypes,
18+
get_float_dtypes,
1819
has_support_aspect64,
1920
is_cpu_device,
2021
is_win_platform,
@@ -597,119 +598,125 @@ def test_gradient_y1_dx(self, array, dx):
597598

598599

599600
class TestCeil:
600-
def test_ceil(self):
601+
@pytest.mark.parametrize("dtype", get_float_dtypes())
602+
def test_ceil(self, dtype):
601603
array_data = numpy.arange(10)
602-
out = numpy.empty(10, dtype=numpy.float64)
604+
out = numpy.empty(10, dtype)
603605

604606
# DPNP
605-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
606-
dp_out = dpnp.array(out, dtype=dpnp.float64)
607+
dp_array = dpnp.array(array_data, dtype=dtype)
608+
dp_out = dpnp.array(out, dtype=dtype)
607609
result = dpnp.ceil(dp_array, out=dp_out)
608610

609611
# original
610-
np_array = numpy.array(array_data, dtype=numpy.float64)
612+
np_array = numpy.array(array_data, dtype=dtype)
611613
expected = numpy.ceil(np_array, out=out)
612614

613615
assert_array_equal(expected, result)
614616

615617
@pytest.mark.parametrize(
616-
"dtype",
617-
[numpy.float32, numpy.int64, numpy.int32],
618-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
618+
"dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True)
619619
)
620620
def test_invalid_dtype(self, dtype):
621-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
621+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
622+
pytest.skip("similar data types") if dpnp_dtype == dtype else None
623+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
622624
dp_out = dpnp.empty(10, dtype=dtype)
623625

624-
with pytest.raises(ValueError):
626+
with pytest.raises(TypeError):
625627
dpnp.ceil(dp_array, out=dp_out)
626628

629+
@pytest.mark.parametrize("dtype", get_float_dtypes())
627630
@pytest.mark.parametrize(
628631
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
629632
)
630-
def test_invalid_shape(self, shape):
631-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
632-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
633+
def test_invalid_shape(self, shape, dtype):
634+
dp_array = dpnp.arange(10, dtype=dtype)
635+
dp_out = dpnp.empty(shape, dtype=dtype)
633636

634-
with pytest.raises(ValueError):
637+
with pytest.raises(TypeError):
635638
dpnp.ceil(dp_array, out=dp_out)
636639

637640

638641
class TestFloor:
639-
def test_floor(self):
642+
@pytest.mark.parametrize("dtype", get_float_dtypes())
643+
def test_floor(self, dtype):
640644
array_data = numpy.arange(10)
641-
out = numpy.empty(10, dtype=numpy.float64)
645+
out = numpy.empty(10, dtype=dtype)
642646

643647
# DPNP
644-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
645-
dp_out = dpnp.array(out, dtype=dpnp.float64)
648+
dp_array = dpnp.array(array_data, dtype=dtype)
649+
dp_out = dpnp.array(out, dtype=dtype)
646650
result = dpnp.floor(dp_array, out=dp_out)
647651

648652
# original
649-
np_array = numpy.array(array_data, dtype=numpy.float64)
653+
np_array = numpy.array(array_data, dtype=dtype)
650654
expected = numpy.floor(np_array, out=out)
651655

652656
assert_array_equal(expected, result)
653657

654658
@pytest.mark.parametrize(
655-
"dtype",
656-
[numpy.float32, numpy.int64, numpy.int32],
657-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
659+
"dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True)
658660
)
659661
def test_invalid_dtype(self, dtype):
660-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
662+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
663+
pytest.skip("similar data types") if dpnp_dtype == dtype else None
664+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
661665
dp_out = dpnp.empty(10, dtype=dtype)
662666

663-
with pytest.raises(ValueError):
667+
with pytest.raises(TypeError):
664668
dpnp.floor(dp_array, out=dp_out)
665669

670+
@pytest.mark.parametrize("dtype", get_float_dtypes())
666671
@pytest.mark.parametrize(
667672
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
668673
)
669-
def test_invalid_shape(self, shape):
670-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
671-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
674+
def test_invalid_shape(self, shape, dtype):
675+
dp_array = dpnp.arange(10, dtype=dtype)
676+
dp_out = dpnp.empty(shape, dtype=dtype)
672677

673-
with pytest.raises(ValueError):
678+
with pytest.raises(TypeError):
674679
dpnp.floor(dp_array, out=dp_out)
675680

676681

677682
class TestTrunc:
678-
def test_trunc(self):
683+
@pytest.mark.parametrize("dtype", get_float_dtypes())
684+
def test_trunc(self, dtype):
679685
array_data = numpy.arange(10)
680-
out = numpy.empty(10, dtype=numpy.float64)
686+
out = numpy.empty(10, dtype=dtype)
681687

682688
# DPNP
683-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
684-
dp_out = dpnp.array(out, dtype=dpnp.float64)
689+
dp_array = dpnp.array(array_data, dtype=dtype)
690+
dp_out = dpnp.array(out, dtype=dtype)
685691
result = dpnp.trunc(dp_array, out=dp_out)
686692

687693
# original
688-
np_array = numpy.array(array_data, dtype=numpy.float64)
694+
np_array = numpy.array(array_data, dtype=dtype)
689695
expected = numpy.trunc(np_array, out=out)
690696

691697
assert_array_equal(expected, result)
692698

693699
@pytest.mark.parametrize(
694-
"dtype",
695-
[numpy.float32, numpy.int64, numpy.int32],
696-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
700+
"dtype", get_all_dtypes(no_bool=True, no_complex=True, no_none=True)
697701
)
698702
def test_invalid_dtype(self, dtype):
699-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
703+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
704+
pytest.skip("similar data types") if dpnp_dtype == dtype else None
705+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
700706
dp_out = dpnp.empty(10, dtype=dtype)
701707

702-
with pytest.raises(ValueError):
708+
with pytest.raises(TypeError):
703709
dpnp.trunc(dp_array, out=dp_out)
704710

711+
@pytest.mark.parametrize("dtype", get_float_dtypes())
705712
@pytest.mark.parametrize(
706713
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
707714
)
708-
def test_invalid_shape(self, shape):
709-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
710-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
715+
def test_invalid_shape(self, shape, dtype):
716+
dp_array = dpnp.arange(10, dtype=dtype)
717+
dp_out = dpnp.empty(shape, dtype=dtype)
711718

712-
with pytest.raises(ValueError):
719+
with pytest.raises(TypeError):
713720
dpnp.trunc(dp_array, out=dp_out)
714721

715722

tests/test_usm_type.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ def test_meshgrid(usm_type_x, usm_type_y):
216216
@pytest.mark.parametrize(
217217
"func,data",
218218
[
219-
pytest.param(
220-
"sqrt",
221-
[1.0, 3.0, 9.0],
222-
),
219+
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
220+
pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
221+
pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
222+
pytest.param("sqrt", [1.0, 3.0, 9.0]),
223223
],
224224
)
225225
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)

0 commit comments

Comments
 (0)