Skip to content

Commit dec3e57

Browse files
committed
remove unneccessary parts with updates in dpctl #1465
1 parent e026c37 commit dec3e57

File tree

2 files changed

+14
-65
lines changed

2 files changed

+14
-65
lines changed

dpnp/dpnp_iface_searching.py

+6-32
Original file line numberDiff line numberDiff line change
@@ -110,22 +110,9 @@ def argmax(a, axis=None, out=None, *, keepdims=False):
110110
"""
111111

112112
dpt_array = dpnp.get_usm_ndarray(a)
113-
if dpt_array.size == 0:
114-
# TODO: get rid of this if condition when dpctl supports it
115-
for i in range(a.ndim):
116-
if a.shape[i] == 0:
117-
if i == axis or axis is None:
118-
raise ValueError(
119-
"reduction does not support zero-size arrays"
120-
)
121-
else:
122-
indices = [i for i in range(a.ndim) if i != axis]
123-
res_shape = tuple([a.shape[i] for i in indices])
124-
result = dpnp.empty(res_shape, dtype=int)
125-
else:
126-
result = dpnp_array._create_from_usm_ndarray(
127-
dpt.argmax(dpt_array, axis=axis, keepdims=keepdims)
128-
)
113+
result = dpnp_array._create_from_usm_ndarray(
114+
dpt.argmax(dpt_array, axis=axis, keepdims=keepdims)
115+
)
129116

130117
if out is None:
131118
return result
@@ -210,22 +197,9 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
210197
"""
211198

212199
dpt_array = dpnp.get_usm_ndarray(a)
213-
if dpt_array.size == 0:
214-
# TODO: get rid of this if condition when dpctl supports it
215-
for i in range(a.ndim):
216-
if a.shape[i] == 0:
217-
if i == axis or axis is None:
218-
raise ValueError(
219-
"reduction does not support zero-size arrays"
220-
)
221-
else:
222-
indices = [i for i in range(a.ndim) if i != axis]
223-
res_shape = tuple([a.shape[i] for i in indices])
224-
result = dpnp.empty(res_shape, dtype=int)
225-
else:
226-
result = dpnp_array._create_from_usm_ndarray(
227-
dpt.argmin(dpt_array, axis=axis, keepdims=keepdims)
228-
)
200+
result = dpnp_array._create_from_usm_ndarray(
201+
dpt.argmin(dpt_array, axis=axis, keepdims=keepdims)
202+
)
229203

230204
if out is None:
231205
return result

dpnp/dpnp_iface_statistics.py

+8-33
Original file line numberDiff line numberDiff line change
@@ -410,23 +410,10 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
410410
)
411411
else:
412412
dpt_array = dpnp.get_usm_ndarray(a)
413-
if dpt_array.size == 0:
414-
# TODO: get rid of this if condition when dpctl supports it
415-
axis = (axis,) if isinstance(axis, int) else axis
416-
for i in range(a.ndim):
417-
if a.shape[i] == 0:
418-
if axis is None or i in axis:
419-
raise ValueError(
420-
"reduction does not support zero-size arrays"
421-
)
422-
else:
423-
indices = [i for i in range(a.ndim) if i not in axis]
424-
res_shape = tuple([a.shape[i] for i in indices])
425-
result = dpnp.empty(res_shape, dtype=a.dtype)
426-
else:
427-
result = dpnp_array._create_from_usm_ndarray(
428-
dpt.max(dpt_array, axis=axis, keepdims=keepdims)
429-
)
413+
result = dpnp_array._create_from_usm_ndarray(
414+
dpt.max(dpt_array, axis=axis, keepdims=keepdims)
415+
)
416+
430417
if out is None:
431418
return result
432419
else:
@@ -655,22 +642,10 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
655642
)
656643
else:
657644
dpt_array = dpnp.get_usm_ndarray(a)
658-
if dpt_array.size == 0:
659-
# TODO: get rid of this if condition when dpctl supports it
660-
for i in range(a.ndim):
661-
if a.shape[i] == 0:
662-
if axis is None or i in axis:
663-
raise ValueError(
664-
"reduction does not support zero-size arrays"
665-
)
666-
else:
667-
indices = [i for i in range(a.ndim) if i not in axis]
668-
res_shape = tuple([a.shape[i] for i in indices])
669-
result = dpnp.empty(res_shape, dtype=a.dtype)
670-
else:
671-
result = dpnp_array._create_from_usm_ndarray(
672-
dpt.min(dpt_array, axis=axis, keepdims=keepdims)
673-
)
645+
result = dpnp_array._create_from_usm_ndarray(
646+
dpt.min(dpt_array, axis=axis, keepdims=keepdims)
647+
)
648+
674649
if out is None:
675650
return result
676651
else:

0 commit comments

Comments
 (0)