Skip to content

Commit 5bcf910

Browse files
authored
rework implementation of diag, diagflat, vander, and ptp (#1579)
* rework implementation of diag, diagflat, vander, and ptp * address comments - first round cherry-pick * address comments - second round * add tests for negative use cases to improve covergae * fixed missing merge conflicts * fix pre-commit
1 parent 630dae2 commit 5bcf910

16 files changed

+590
-485
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

+32-37
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,7 @@ enum class DPNPFuncName : size_t
132132
DPNP_FN_DET_EXT, /**< Used in numpy.linalg.det() impl, requires extra
133133
parameters */
134134
DPNP_FN_DIAG, /**< Used in numpy.diag() impl */
135-
DPNP_FN_DIAG_EXT, /**< Used in numpy.diag() impl, requires extra parameters
136-
*/
137-
DPNP_FN_DIAG_INDICES, /**< Used in numpy.diag_indices() impl */
135+
DPNP_FN_DIAG_INDICES, /**< Used in numpy.diag_indices() impl */
138136
DPNP_FN_DIAG_INDICES_EXT, /**< Used in numpy.diag_indices() impl, requires
139137
extra parameters */
140138
DPNP_FN_DIAGONAL, /**< Used in numpy.diagonal() impl */
@@ -225,25 +223,24 @@ enum class DPNPFuncName : size_t
225223
DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters
226224
*/
227225
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() impl */
228-
DPNP_FN_MULTIPLY_EXT, /**< Used in numpy.multiply() impl, requires extra
229-
parameters */
230-
DPNP_FN_NANVAR, /**< Used in numpy.nanvar() impl */
231-
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
232-
parameters */
233-
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
234-
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
235-
DPNP_FN_ONES, /**< Used in numpy.ones() impl */
236-
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() impl */
237-
DPNP_FN_PARTITION, /**< Used in numpy.partition() impl */
238-
DPNP_FN_PARTITION_EXT, /**< Used in numpy.partition() impl, requires extra
239-
parameters */
240-
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
241-
DPNP_FN_POWER, /**< Used in numpy.power() impl */
242-
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
243-
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
244-
DPNP_FN_PTP_EXT, /**< Used in numpy.ptp() impl, requires extra parameters */
245-
DPNP_FN_PUT, /**< Used in numpy.put() impl */
246-
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
226+
DPNP_FN_MULTIPLY_EXT, /**< Used in numpy.multiply() impl, requires extra
227+
parameters */
228+
DPNP_FN_NANVAR, /**< Used in numpy.nanvar() impl */
229+
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
230+
parameters */
231+
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
232+
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
233+
DPNP_FN_ONES, /**< Used in numpy.ones() impl */
234+
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() impl */
235+
DPNP_FN_PARTITION, /**< Used in numpy.partition() impl */
236+
DPNP_FN_PARTITION_EXT, /**< Used in numpy.partition() impl, requires extra
237+
parameters */
238+
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
239+
DPNP_FN_POWER, /**< Used in numpy.power() impl */
240+
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
241+
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
242+
DPNP_FN_PUT, /**< Used in numpy.put() impl */
243+
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
247244
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
248245
requires extra parameters */
249246
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
@@ -401,21 +398,19 @@ enum class DPNPFuncName : size_t
401398
DPNP_FN_TAKE, /**< Used in numpy.take() impl */
402399
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
403400
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
404-
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
405-
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
406-
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
407-
parameters */
408-
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
409-
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
410-
parameters */
411-
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
412-
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
413-
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
414-
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
415-
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
416-
DPNP_FN_VANDER_EXT, /**< Used in numpy.vander() impl, requires extra
417-
parameters */
418-
DPNP_FN_VAR, /**< Used in numpy.var() impl */
401+
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
402+
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
403+
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
404+
parameters */
405+
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
406+
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
407+
parameters */
408+
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
409+
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
410+
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
411+
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
412+
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
413+
DPNP_FN_VAR, /**< Used in numpy.var() impl */
419414
DPNP_FN_VAR_EXT, /**< Used in numpy.var() impl, requires extra parameters */
420415
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
421416
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

-74
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,6 @@ void (*dpnp_diag_default_c)(void *,
200200
const size_t,
201201
const size_t) = dpnp_diag_c<_DataType>;
202202

203-
template <typename _DataType>
204-
DPCTLSyclEventRef (*dpnp_diag_ext_c)(DPCTLSyclQueueRef,
205-
void *,
206-
void *,
207-
const int,
208-
shape_elem_type *,
209-
shape_elem_type *,
210-
const size_t,
211-
const size_t,
212-
const DPCTLEventVectorRef) =
213-
dpnp_diag_c<_DataType>;
214-
215203
template <typename _DataType>
216204
DPCTLSyclEventRef dpnp_eye_c(DPCTLSyclQueueRef q_ref,
217205
void *result1,
@@ -569,23 +557,6 @@ void (*dpnp_ptp_default_c)(void *,
569557
const shape_elem_type *,
570558
const size_t) = dpnp_ptp_c<_DataType>;
571559

572-
template <typename _DataType>
573-
DPCTLSyclEventRef (*dpnp_ptp_ext_c)(DPCTLSyclQueueRef,
574-
void *,
575-
const size_t,
576-
const size_t,
577-
const shape_elem_type *,
578-
const shape_elem_type *,
579-
const void *,
580-
const size_t,
581-
const size_t,
582-
const shape_elem_type *,
583-
const shape_elem_type *,
584-
const shape_elem_type *,
585-
const size_t,
586-
const DPCTLEventVectorRef) =
587-
dpnp_ptp_c<_DataType>;
588-
589560
template <typename _DataType_input, typename _DataType_output>
590561
DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
591562
const void *array1_in,
@@ -673,16 +644,6 @@ void (*dpnp_vander_default_c)(const void *,
673644
const int) =
674645
dpnp_vander_c<_DataType_input, _DataType_output>;
675646

676-
template <typename _DataType_input, typename _DataType_output>
677-
DPCTLSyclEventRef (*dpnp_vander_ext_c)(DPCTLSyclQueueRef,
678-
const void *,
679-
void *,
680-
const size_t,
681-
const size_t,
682-
const int,
683-
const DPCTLEventVectorRef) =
684-
dpnp_vander_c<_DataType_input, _DataType_output>;
685-
686647
template <typename _DataType, typename _ResultType>
687648
class dpnp_trace_c_kernel;
688649

@@ -1192,15 +1153,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
11921153
fmap[DPNPFuncName::DPNP_FN_DIAG][eft_DBL][eft_DBL] = {
11931154
eft_DBL, (void *)dpnp_diag_default_c<double>};
11941155

1195-
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_INT][eft_INT] = {
1196-
eft_INT, (void *)dpnp_diag_ext_c<int32_t>};
1197-
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_LNG][eft_LNG] = {
1198-
eft_LNG, (void *)dpnp_diag_ext_c<int64_t>};
1199-
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_FLT][eft_FLT] = {
1200-
eft_FLT, (void *)dpnp_diag_ext_c<float>};
1201-
fmap[DPNPFuncName::DPNP_FN_DIAG_EXT][eft_DBL][eft_DBL] = {
1202-
eft_DBL, (void *)dpnp_diag_ext_c<double>};
1203-
12041156
fmap[DPNPFuncName::DPNP_FN_EYE][eft_INT][eft_INT] = {
12051157
eft_INT, (void *)dpnp_eye_default_c<int32_t>};
12061158
fmap[DPNPFuncName::DPNP_FN_EYE][eft_LNG][eft_LNG] = {
@@ -1284,15 +1236,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
12841236
fmap[DPNPFuncName::DPNP_FN_PTP][eft_DBL][eft_DBL] = {
12851237
eft_DBL, (void *)dpnp_ptp_default_c<double>};
12861238

1287-
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_INT][eft_INT] = {
1288-
eft_INT, (void *)dpnp_ptp_ext_c<int32_t>};
1289-
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_LNG][eft_LNG] = {
1290-
eft_LNG, (void *)dpnp_ptp_ext_c<int64_t>};
1291-
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_FLT][eft_FLT] = {
1292-
eft_FLT, (void *)dpnp_ptp_ext_c<float>};
1293-
fmap[DPNPFuncName::DPNP_FN_PTP_EXT][eft_DBL][eft_DBL] = {
1294-
eft_DBL, (void *)dpnp_ptp_ext_c<double>};
1295-
12961239
fmap[DPNPFuncName::DPNP_FN_VANDER][eft_INT][eft_INT] = {
12971240
eft_LNG, (void *)dpnp_vander_default_c<int32_t, int64_t>};
12981241
fmap[DPNPFuncName::DPNP_FN_VANDER][eft_LNG][eft_LNG] = {
@@ -1308,23 +1251,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
13081251
(void *)
13091252
dpnp_vander_default_c<std::complex<double>, std::complex<double>>};
13101253

1311-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_INT][eft_INT] = {
1312-
eft_LNG, (void *)dpnp_vander_ext_c<int32_t, int64_t>};
1313-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_LNG][eft_LNG] = {
1314-
eft_LNG, (void *)dpnp_vander_ext_c<int64_t, int64_t>};
1315-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {
1316-
eft_FLT, (void *)dpnp_vander_ext_c<float, float>};
1317-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_DBL][eft_DBL] = {
1318-
eft_DBL, (void *)dpnp_vander_ext_c<double, double>};
1319-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_BLN][eft_BLN] = {
1320-
eft_LNG, (void *)dpnp_vander_ext_c<bool, int64_t>};
1321-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C64][eft_C64] = {
1322-
eft_C64,
1323-
(void *)dpnp_vander_ext_c<std::complex<float>, std::complex<float>>};
1324-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C128][eft_C128] = {
1325-
eft_C128,
1326-
(void *)dpnp_vander_ext_c<std::complex<double>, std::complex<double>>};
1327-
13281254
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_INT][eft_INT] = {
13291255
eft_INT, (void *)dpnp_trace_default_c<int32_t, int32_t>};
13301256
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_LNG][eft_INT] = {

dpnp/dpnp_algo/dpnp_algo.pxd

-6
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
6464
DPNP_FN_DEGREES_EXT
6565
DPNP_FN_DET
6666
DPNP_FN_DET_EXT
67-
DPNP_FN_DIAG
68-
DPNP_FN_DIAG_EXT
6967
DPNP_FN_DIAG_INDICES
7068
DPNP_FN_DIAG_INDICES_EXT
7169
DPNP_FN_DIAGONAL
@@ -120,8 +118,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
120118
DPNP_FN_PARTITION
121119
DPNP_FN_PARTITION_EXT
122120
DPNP_FN_PLACE
123-
DPNP_FN_PTP
124-
DPNP_FN_PTP_EXT
125121
DPNP_FN_QR
126122
DPNP_FN_QR_EXT
127123
DPNP_FN_RADIANS
@@ -218,8 +214,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
218214
DPNP_FN_TRIL_EXT
219215
DPNP_FN_TRIU
220216
DPNP_FN_TRIU_EXT
221-
DPNP_FN_VANDER
222-
DPNP_FN_VANDER_EXT
223217
DPNP_FN_VAR
224218
DPNP_FN_VAR_EXT
225219
DPNP_FN_ZEROS

0 commit comments

Comments
 (0)