Skip to content

Commit f6b87a6

Browse files
committed
address reviewers' comments
1 parent 34c3703 commit f6b87a6

File tree

2 files changed

+48
-37
lines changed

2 files changed

+48
-37
lines changed

tests/test_mathematical.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -596,120 +596,131 @@ def test_gradient_y1_dx(self, array, dx):
596596
assert_array_equal(expected, result)
597597

598598

599+
if has_support_aspect64():
600+
dtype_list = [numpy.float32, numpy.int64, numpy.int32]
601+
ids_list = ["numpy.float32", "numpy.int64", "numpy.int32"]
602+
else:
603+
dtype_list = [numpy.int64, numpy.int32]
604+
ids_list = ["numpy.int64", "numpy.int32"]
605+
606+
599607
class TestCeil:
600608
def test_ceil(self):
601609
array_data = numpy.arange(10)
602-
out = numpy.empty(10, dtype=numpy.float64)
610+
out = numpy.empty(10, dtype=numpy.float32)
603611

604612
# DPNP
605-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
606-
dp_out = dpnp.array(out, dtype=dpnp.float64)
613+
dp_array = dpnp.array(array_data, dtype=dpnp.float32)
614+
dp_out = dpnp.array(out, dtype=dpnp.float32)
607615
result = dpnp.ceil(dp_array, out=dp_out)
608616

609617
# original
610-
np_array = numpy.array(array_data, dtype=numpy.float64)
618+
np_array = numpy.array(array_data, dtype=numpy.float32)
611619
expected = numpy.ceil(np_array, out=out)
612620

613621
assert_array_equal(expected, result)
614622

615623
@pytest.mark.parametrize(
616624
"dtype",
617-
[numpy.float32, numpy.int64, numpy.int32],
618-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
625+
dtype_list,
626+
ids=ids_list,
619627
)
620628
def test_invalid_dtype(self, dtype):
621-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
629+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
630+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
622631
dp_out = dpnp.empty(10, dtype=dtype)
623632

624-
with pytest.raises(ValueError):
633+
with pytest.raises(TypeError):
625634
dpnp.ceil(dp_array, out=dp_out)
626635

627636
@pytest.mark.parametrize(
628637
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
629638
)
630639
def test_invalid_shape(self, shape):
631-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
632-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
640+
dp_array = dpnp.arange(10, dtype=dpnp.float32)
641+
dp_out = dpnp.empty(shape, dtype=dpnp.float32)
633642

634-
with pytest.raises(ValueError):
643+
with pytest.raises(TypeError):
635644
dpnp.ceil(dp_array, out=dp_out)
636645

637646

638647
class TestFloor:
639648
def test_floor(self):
640649
array_data = numpy.arange(10)
641-
out = numpy.empty(10, dtype=numpy.float64)
650+
out = numpy.empty(10, dtype=numpy.float32)
642651

643652
# DPNP
644-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
645-
dp_out = dpnp.array(out, dtype=dpnp.float64)
653+
dp_array = dpnp.array(array_data, dtype=dpnp.float32)
654+
dp_out = dpnp.array(out, dtype=dpnp.float32)
646655
result = dpnp.floor(dp_array, out=dp_out)
647656

648657
# original
649-
np_array = numpy.array(array_data, dtype=numpy.float64)
658+
np_array = numpy.array(array_data, dtype=numpy.float32)
650659
expected = numpy.floor(np_array, out=out)
651660

652661
assert_array_equal(expected, result)
653662

654663
@pytest.mark.parametrize(
655664
"dtype",
656-
[numpy.float32, numpy.int64, numpy.int32],
657-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
665+
dtype_list,
666+
ids=ids_list,
658667
)
659668
def test_invalid_dtype(self, dtype):
660-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
669+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
670+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
661671
dp_out = dpnp.empty(10, dtype=dtype)
662672

663-
with pytest.raises(ValueError):
673+
with pytest.raises(TypeError):
664674
dpnp.floor(dp_array, out=dp_out)
665675

666676
@pytest.mark.parametrize(
667677
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
668678
)
669679
def test_invalid_shape(self, shape):
670-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
671-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
680+
dp_array = dpnp.arange(10, dtype=dpnp.float32)
681+
dp_out = dpnp.empty(shape, dtype=dpnp.float32)
672682

673-
with pytest.raises(ValueError):
683+
with pytest.raises(TypeError):
674684
dpnp.floor(dp_array, out=dp_out)
675685

676686

677687
class TestTrunc:
678688
def test_trunc(self):
679689
array_data = numpy.arange(10)
680-
out = numpy.empty(10, dtype=numpy.float64)
690+
out = numpy.empty(10, dtype=numpy.float32)
681691

682692
# DPNP
683-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
684-
dp_out = dpnp.array(out, dtype=dpnp.float64)
693+
dp_array = dpnp.array(array_data, dtype=dpnp.float32)
694+
dp_out = dpnp.array(out, dtype=dpnp.float32)
685695
result = dpnp.trunc(dp_array, out=dp_out)
686696

687697
# original
688-
np_array = numpy.array(array_data, dtype=numpy.float64)
698+
np_array = numpy.array(array_data, dtype=numpy.float32)
689699
expected = numpy.trunc(np_array, out=out)
690700

691701
assert_array_equal(expected, result)
692702

693703
@pytest.mark.parametrize(
694704
"dtype",
695-
[numpy.float32, numpy.int64, numpy.int32],
696-
ids=["numpy.float32", "numpy.int64", "numpy.int32"],
705+
dtype_list,
706+
ids=ids_list,
697707
)
698708
def test_invalid_dtype(self, dtype):
699-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
709+
dpnp_dtype = dpnp.float64 if has_support_aspect64() else dpnp.float32
710+
dp_array = dpnp.arange(10, dtype=dpnp_dtype)
700711
dp_out = dpnp.empty(10, dtype=dtype)
701712

702-
with pytest.raises(ValueError):
713+
with pytest.raises(TypeError):
703714
dpnp.trunc(dp_array, out=dp_out)
704715

705716
@pytest.mark.parametrize(
706717
"shape", [(0,), (15,), (2, 2)], ids=["(0,)", "(15, )", "(2,2)"]
707718
)
708719
def test_invalid_shape(self, shape):
709-
dp_array = dpnp.arange(10, dtype=dpnp.float64)
710-
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
720+
dp_array = dpnp.arange(10, dtype=dpnp.float32)
721+
dp_out = dpnp.empty(shape, dtype=dpnp.float32)
711722

712-
with pytest.raises(ValueError):
723+
with pytest.raises(TypeError):
713724
dpnp.trunc(dp_array, out=dp_out)
714725

715726

tests/test_usm_type.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ def test_meshgrid(usm_type_x, usm_type_y):
216216
@pytest.mark.parametrize(
217217
"func,data",
218218
[
219-
pytest.param(
220-
"sqrt",
221-
[1.0, 3.0, 9.0],
222-
),
219+
pytest.param("ceil", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
220+
pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
221+
pytest.param("trunc", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]),
222+
pytest.param("sqrt", [1.0, 3.0, 9.0]),
223223
],
224224
)
225225
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)

0 commit comments

Comments
 (0)