Skip to content

Commit 7d0815f

Browse files
vtavanaantonwolfy
andauthored
implement dpnp.argmin and dpnp.argmax using dpctl.tensor (#1610)
* rework implementation of diag, diagflat, vander, and ptp * address comments - first round cherry-pick * address comments - second round * add tests for negative use cases to improve covergae * fixed missing merge conflicts * fix pre-commit * implement dpnp.argmin and dpnp.argmax using dpctl.tensor * address comments * add tests for negative use cases to improve coverage * remove unneccessary parts with updates in dpctl #1465 * add paramater section in doc * update ndarray.argmin and ndarray.argmax function signature * use a utility func for returning output * add tests for ndarray implementation * Place new function acc to lexicographical order --------- Co-authored-by: Anton Volkov <[email protected]> Co-authored-by: Anton <[email protected]>
1 parent 525116a commit 7d0815f

18 files changed

+346
-392
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

-4
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,7 @@ enum class DPNPFuncName : size_t
7676
DPNP_FN_ARCTAN2, /**< Used in numpy.arctan2() impl */
7777
DPNP_FN_ARCTANH, /**< Used in numpy.arctanh() impl */
7878
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() impl */
79-
DPNP_FN_ARGMAX_EXT, /**< Used in numpy.argmax() impl, requires extra
80-
parameters */
8179
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() impl */
82-
DPNP_FN_ARGMIN_EXT, /**< Used in numpy.argmin() impl, requires extra
83-
parameters */
8480
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() impl */
8581
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
8682
parameters */

dpnp/backend/kernels/dpnp_krnl_searching.cpp

-50
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,6 @@ void (*dpnp_argmax_default_c)(void *,
7878
void *,
7979
size_t) = dpnp_argmax_c<_DataType, _idx_DataType>;
8080

81-
template <typename _DataType, typename _idx_DataType>
82-
DPCTLSyclEventRef (*dpnp_argmax_ext_c)(DPCTLSyclQueueRef,
83-
void *,
84-
void *,
85-
size_t,
86-
const DPCTLEventVectorRef) =
87-
dpnp_argmax_c<_DataType, _idx_DataType>;
88-
8981
template <typename _DataType, typename _idx_DataType>
9082
class dpnp_argmin_c_kernel;
9183

@@ -133,14 +125,6 @@ void (*dpnp_argmin_default_c)(void *,
133125
void *,
134126
size_t) = dpnp_argmin_c<_DataType, _idx_DataType>;
135127

136-
template <typename _DataType, typename _idx_DataType>
137-
DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
138-
void *,
139-
void *,
140-
size_t,
141-
const DPCTLEventVectorRef) =
142-
dpnp_argmin_c<_DataType, _idx_DataType>;
143-
144128
void func_map_init_searching(func_map_t &fmap)
145129
{
146130
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {
@@ -160,23 +144,6 @@ void func_map_init_searching(func_map_t &fmap)
160144
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_DBL][eft_LNG] = {
161145
eft_LNG, (void *)dpnp_argmax_default_c<double, int64_t>};
162146

163-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_INT] = {
164-
eft_INT, (void *)dpnp_argmax_ext_c<int32_t, int32_t>};
165-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_INT][eft_LNG] = {
166-
eft_LNG, (void *)dpnp_argmax_ext_c<int32_t, int64_t>};
167-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_INT] = {
168-
eft_INT, (void *)dpnp_argmax_ext_c<int64_t, int32_t>};
169-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_LNG][eft_LNG] = {
170-
eft_LNG, (void *)dpnp_argmax_ext_c<int64_t, int64_t>};
171-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_INT] = {
172-
eft_INT, (void *)dpnp_argmax_ext_c<float, int32_t>};
173-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_FLT][eft_LNG] = {
174-
eft_LNG, (void *)dpnp_argmax_ext_c<float, int64_t>};
175-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_INT] = {
176-
eft_INT, (void *)dpnp_argmax_ext_c<double, int32_t>};
177-
fmap[DPNPFuncName::DPNP_FN_ARGMAX_EXT][eft_DBL][eft_LNG] = {
178-
eft_LNG, (void *)dpnp_argmax_ext_c<double, int64_t>};
179-
180147
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_INT] = {
181148
eft_INT, (void *)dpnp_argmin_default_c<int32_t, int32_t>};
182149
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_INT][eft_LNG] = {
@@ -194,22 +161,5 @@ void func_map_init_searching(func_map_t &fmap)
194161
fmap[DPNPFuncName::DPNP_FN_ARGMIN][eft_DBL][eft_LNG] = {
195162
eft_LNG, (void *)dpnp_argmin_default_c<double, int64_t>};
196163

197-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_INT] = {
198-
eft_INT, (void *)dpnp_argmin_ext_c<int32_t, int32_t>};
199-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_INT][eft_LNG] = {
200-
eft_LNG, (void *)dpnp_argmin_ext_c<int32_t, int64_t>};
201-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_INT] = {
202-
eft_INT, (void *)dpnp_argmin_ext_c<int64_t, int32_t>};
203-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_LNG][eft_LNG] = {
204-
eft_LNG, (void *)dpnp_argmin_ext_c<int64_t, int64_t>};
205-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_INT] = {
206-
eft_INT, (void *)dpnp_argmin_ext_c<float, int32_t>};
207-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_FLT][eft_LNG] = {
208-
eft_LNG, (void *)dpnp_argmin_ext_c<float, int64_t>};
209-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {
210-
eft_INT, (void *)dpnp_argmin_ext_c<double, int32_t>};
211-
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {
212-
eft_LNG, (void *)dpnp_argmin_ext_c<double, int64_t>};
213-
214164
return;
215165
}

dpnp/dpnp_algo/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ set(dpnp_algo_pyx_deps
66
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi
77
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_arraycreation.pxi
88
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_mathematical.pxi
9-
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_searching.pxi
109
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_indexing.pxi
1110
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_logic.pxi
1211
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_special.pxi

dpnp/dpnp_algo/dpnp_algo.pxd

-10
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3636
DPNP_FN_ALLCLOSE
3737
DPNP_FN_ALLCLOSE_EXT
3838
DPNP_FN_ARANGE
39-
DPNP_FN_ARGMAX
40-
DPNP_FN_ARGMAX_EXT
41-
DPNP_FN_ARGMIN
42-
DPNP_FN_ARGMIN_EXT
4339
DPNP_FN_ARGSORT
4440
DPNP_FN_ARGSORT_EXT
4541
DPNP_FN_CBRT
@@ -355,12 +351,6 @@ Sorting functions
355351
cpdef dpnp_descriptor dpnp_argsort(dpnp_descriptor array1)
356352
cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)
357353

358-
"""
359-
Searching functions
360-
"""
361-
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
362-
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)
363-
364354
"""
365355
Trigonometric functions
366356
"""

dpnp/dpnp_algo/dpnp_algo.pyx

-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ include "dpnp_algo_indexing.pxi"
6363
include "dpnp_algo_linearalgebra.pxi"
6464
include "dpnp_algo_logic.pxi"
6565
include "dpnp_algo_mathematical.pxi"
66-
include "dpnp_algo_searching.pxi"
6766
include "dpnp_algo_sorting.pxi"
6867
include "dpnp_algo_special.pxi"
6968
include "dpnp_algo_statistics.pxi"

dpnp/dpnp_algo/dpnp_algo_searching.pxi

-119
This file was deleted.

dpnp/dpnp_array.py

+6-41
Original file line numberDiff line numberDiff line change
@@ -486,58 +486,23 @@ def any(self, axis=None, out=None, keepdims=False, *, where=True):
486486
self, axis=axis, out=out, keepdims=keepdims, where=where
487487
)
488488

489-
def argmax(self, axis=None, out=None):
489+
def argmax(self, axis=None, out=None, *, keepdims=False):
490490
"""
491491
Returns array of indices of the maximum values along the given axis.
492492
493-
Parameters
494-
----------
495-
axis : {None, integer}
496-
If None, the index is into the flattened array, otherwise along
497-
the specified axis
498-
out : {None, array}, optional
499-
Array into which the result can be placed. Its type is preserved
500-
and it must be of the right shape to hold the output.
501-
502-
Returns
503-
-------
504-
index_array : {integer_array}
505-
506-
Examples
507-
--------
508-
>>> a = np.arange(6).reshape(2,3)
509-
>>> a.argmax()
510-
5
511-
>>> a.argmax(0)
512-
array([1, 1, 1])
513-
>>> a.argmax(1)
514-
array([2, 2])
493+
Refer to :obj:`dpnp.argmax` for full documentation.
515494
516495
"""
517-
return dpnp.argmax(self, axis, out)
496+
return dpnp.argmax(self, axis, out, keepdims=keepdims)
518497

519-
def argmin(self, axis=None, out=None):
498+
def argmin(self, axis=None, out=None, *, keepdims=False):
520499
"""
521500
Return array of indices to the minimum values along the given axis.
522501
523-
Parameters
524-
----------
525-
axis : {None, integer}
526-
If None, the index is into the flattened array, otherwise along
527-
the specified axis
528-
out : {None, array}, optional
529-
Array into which the result can be placed. Its type is preserved
530-
and it must be of the right shape to hold the output.
531-
532-
Returns
533-
-------
534-
ndarray or scalar
535-
If multi-dimension input, returns a new ndarray of indices to the
536-
minimum values along the given axis. Otherwise, returns a scalar
537-
of index to the minimum values along the given axis.
502+
Refer to :obj:`dpnp.argmin` for full documentation.
538503
539504
"""
540-
return dpnp.argmin(self, axis, out)
505+
return dpnp.argmin(self, axis, out, keepdims=keepdims)
541506

542507
# 'argpartition',
543508

dpnp/dpnp_iface.py

+46
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"get_dpnp_descriptor",
6666
"get_include",
6767
"get_normalized_queue_device",
68+
"get_result_array",
6869
"get_usm_ndarray",
6970
"get_usm_ndarray_or_scalar",
7071
"is_supported_array_or_scalar",
@@ -418,6 +419,51 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None):
418419
)
419420

420421

422+
def get_result_array(a, out=None):
423+
"""
424+
If `out` is provided, value of `a` array will be copied into the
425+
`out` array according to ``safe`` casting rule.
426+
Otherwise, the input array `a` is returned.
427+
428+
Parameters
429+
----------
430+
a : {dpnp_array}
431+
Input array.
432+
433+
out : {dpnp_array, usm_ndarray}
434+
If provided, value of `a` array will be copied into it
435+
according to ``safe`` casting rule.
436+
It should be of the appropriate shape.
437+
438+
Returns
439+
-------
440+
out : {dpnp_array}
441+
Return `out` if provided, otherwise return `a`.
442+
443+
"""
444+
445+
if out is None:
446+
return a
447+
else:
448+
if out.shape != a.shape:
449+
raise ValueError(
450+
f"Output array of shape {a.shape} is needed, got {out.shape}."
451+
)
452+
elif not isinstance(out, dpnp_array):
453+
if isinstance(out, dpt.usm_ndarray):
454+
out = dpnp_array._create_from_usm_ndarray(out)
455+
else:
456+
raise TypeError(
457+
"Output array must be any of supported type, but got {}".format(
458+
type(out)
459+
)
460+
)
461+
462+
dpnp.copyto(out, a, casting="safe")
463+
464+
return out
465+
466+
421467
def get_usm_ndarray(a):
422468
"""
423469
Return :class:`dpctl.tensor.usm_ndarray` from input array `a`.

0 commit comments

Comments
 (0)