Skip to content

Commit 269c03f

Browse files
committed
updates for nanprod input array
1 parent c2a5fe4 commit 269c03f

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

dpnp/dpnp_iface_mathematical.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1799,14 +1799,15 @@ def nanprod(
17991799
18001800
"""
18011801

1802-
if issubclass(a.dtype.type, dpnp.inexact):
1803-
mask = dpnp.isnan(a)
1802+
if dpnp.is_supported_array_or_scalar(a):
1803+
if issubclass(a.dtype.type, dpnp.inexact):
1804+
mask = dpnp.isnan(a)
1805+
a = dpnp.array(a, copy=True)
1806+
dpnp.copyto(a, 1, where=mask)
18041807
else:
1805-
mask = None
1806-
1807-
if mask is not None:
1808-
a = dpnp.array(a, copy=True)
1809-
dpnp.copyto(a, 1, where=mask)
1808+
raise TypeError(
1809+
"An array must be any of supported type, but got {}".format(type(a))
1810+
)
18101811

18111812
return dpnp.prod(
18121813
a,

tests/test_mathematical.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,13 @@ def test_prod_nanprod_out(func):
549549
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
550550

551551

552-
def test_prod_Error():
552+
def test_prod_nanprod_Error():
553553
ia = dpnp.arange(5)
554554

555555
with pytest.raises(TypeError):
556556
dpnp.prod(dpnp.asnumpy(ia))
557+
with pytest.raises(TypeError):
558+
dpnp.nanprod(dpnp.asnumpy(ia))
557559
with pytest.raises(NotImplementedError):
558560
dpnp.prod(ia, where=False)
559561
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)