Skip to content

Commit 771653b

Browse files
Update test_linalg.py to run on Iris Xe (#1474)
* Update cholesky function * Update dpnp.linalg.eig function * Update dpnp.linalg.eigvals * Update dpnp.linalg.inv() * Update dpnp_norm * Update dpnp.linalg.qr * Update dpnp.linalg.svd * Rename and move get_res_type_with_aspect func * dpnp_inv should return float when got float type
1 parent cfac723 commit 771653b

File tree

8 files changed

+180
-149
lines changed

8 files changed

+180
-149
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,9 +1152,21 @@ void func_map_init_linalg(func_map_t &fmap)
11521152
eft_DBL, (void *)dpnp_eig_default_c<double, double>};
11531153

11541154
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {
1155-
eft_DBL, (void *)dpnp_eig_ext_c<int32_t, double>};
1155+
get_default_floating_type<>(),
1156+
(void *)dpnp_eig_ext_c<
1157+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1158+
get_default_floating_type<std::false_type>(),
1159+
(void *)dpnp_eig_ext_c<
1160+
int32_t, func_type_map_t::find_type<
1161+
get_default_floating_type<std::false_type>()>>};
11561162
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {
1157-
eft_DBL, (void *)dpnp_eig_ext_c<int64_t, double>};
1163+
get_default_floating_type<>(),
1164+
(void *)dpnp_eig_ext_c<
1165+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1166+
get_default_floating_type<std::false_type>(),
1167+
(void *)dpnp_eig_ext_c<
1168+
int64_t, func_type_map_t::find_type<
1169+
get_default_floating_type<std::false_type>()>>};
11581170
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {
11591171
eft_FLT, (void *)dpnp_eig_ext_c<float, float>};
11601172
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {
@@ -1170,9 +1182,21 @@ void func_map_init_linalg(func_map_t &fmap)
11701182
eft_DBL, (void *)dpnp_eigvals_default_c<double, double>};
11711183

11721184
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {
1173-
eft_DBL, (void *)dpnp_eigvals_ext_c<int32_t, double>};
1185+
get_default_floating_type<>(),
1186+
(void *)dpnp_eigvals_ext_c<
1187+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1188+
get_default_floating_type<std::false_type>(),
1189+
(void *)dpnp_eigvals_ext_c<
1190+
int32_t, func_type_map_t::find_type<
1191+
get_default_floating_type<std::false_type>()>>};
11741192
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {
1175-
eft_DBL, (void *)dpnp_eigvals_ext_c<int64_t, double>};
1193+
get_default_floating_type<>(),
1194+
(void *)dpnp_eigvals_ext_c<
1195+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1196+
get_default_floating_type<std::false_type>(),
1197+
(void *)dpnp_eigvals_ext_c<
1198+
int64_t, func_type_map_t::find_type<
1199+
get_default_floating_type<std::false_type>()>>};
11761200
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {
11771201
eft_FLT, (void *)dpnp_eigvals_ext_c<float, float>};
11781202
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -874,16 +874,28 @@ void func_map_init_linalg_func(func_map_t &fmap)
874874
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {
875875
eft_DBL, (void *)dpnp_inv_default_c<int64_t, double>};
876876
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {
877-
eft_DBL, (void *)dpnp_inv_default_c<float, double>};
877+
eft_DBL, (void *)dpnp_inv_default_c<float, float>};
878878
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {
879879
eft_DBL, (void *)dpnp_inv_default_c<double, double>};
880880

881881
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {
882-
eft_DBL, (void *)dpnp_inv_ext_c<int32_t, double>};
882+
get_default_floating_type<>(),
883+
(void *)dpnp_inv_ext_c<
884+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
885+
get_default_floating_type<std::false_type>(),
886+
(void *)dpnp_inv_ext_c<
887+
int32_t, func_type_map_t::find_type<
888+
get_default_floating_type<std::false_type>()>>};
883889
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {
884-
eft_DBL, (void *)dpnp_inv_ext_c<int64_t, double>};
890+
get_default_floating_type<>(),
891+
(void *)dpnp_inv_ext_c<
892+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
893+
get_default_floating_type<std::false_type>(),
894+
(void *)dpnp_inv_ext_c<
895+
int64_t, func_type_map_t::find_type<
896+
get_default_floating_type<std::false_type>()>>};
885897
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {
886-
eft_DBL, (void *)dpnp_inv_ext_c<float, double>};
898+
eft_FLT, (void *)dpnp_inv_ext_c<float, float>};
887899
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {
888900
eft_DBL, (void *)dpnp_inv_ext_c<double, double>};
889901

@@ -1039,9 +1051,21 @@ void func_map_init_linalg_func(func_map_t &fmap)
10391051
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};
10401052

10411053
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {
1042-
eft_DBL, (void *)dpnp_qr_ext_c<int32_t, double>};
1054+
get_default_floating_type<>(),
1055+
(void *)dpnp_qr_ext_c<
1056+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1057+
get_default_floating_type<std::false_type>(),
1058+
(void *)dpnp_qr_ext_c<
1059+
int32_t, func_type_map_t::find_type<
1060+
get_default_floating_type<std::false_type>()>>};
10431061
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {
1044-
eft_DBL, (void *)dpnp_qr_ext_c<int64_t, double>};
1062+
get_default_floating_type<>(),
1063+
(void *)dpnp_qr_ext_c<
1064+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1065+
get_default_floating_type<std::false_type>(),
1066+
(void *)dpnp_qr_ext_c<
1067+
int64_t, func_type_map_t::find_type<
1068+
get_default_floating_type<std::false_type>()>>};
10451069
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {
10461070
eft_FLT, (void *)dpnp_qr_ext_c<float, float>};
10471071
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {
@@ -1062,9 +1086,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
10621086
std::complex<double>, double>};
10631087

10641088
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
1065-
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
1089+
get_default_floating_type<>(),
1090+
(void *)dpnp_svd_ext_c<
1091+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>,
1092+
func_type_map_t::find_type<get_default_floating_type<>()>>,
1093+
get_default_floating_type<std::false_type>(),
1094+
(void *)
1095+
dpnp_svd_ext_c<int32_t,
1096+
func_type_map_t::find_type<
1097+
get_default_floating_type<std::false_type>()>,
1098+
func_type_map_t::find_type<
1099+
get_default_floating_type<std::false_type>()>>};
10661100
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
1067-
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
1101+
get_default_floating_type<>(),
1102+
(void *)dpnp_svd_ext_c<
1103+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>,
1104+
func_type_map_t::find_type<get_default_floating_type<>()>>,
1105+
get_default_floating_type<std::false_type>(),
1106+
(void *)
1107+
dpnp_svd_ext_c<int64_t,
1108+
func_type_map_t::find_type<
1109+
get_default_floating_type<std::false_type>()>,
1110+
func_type_map_t::find_type<
1111+
get_default_floating_type<std::false_type>()>>};
10681112
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
10691113
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
10701114
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,17 @@ class dpnp_less_comp
260260
}
261261
};
262262

263+
/**
264+
* A template function that determines the default floating-point type
265+
* based on the value of the template parameter has_fp64.
266+
*/
267+
template <typename has_fp64 = std::true_type>
268+
static constexpr DPNPFuncType get_default_floating_type()
269+
{
270+
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
271+
: DPNPFuncType::DPNP_FT_FLOAT;
272+
}
273+
263274
/**
264275
* FPTR interface initialization functions
265276
*/

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
336336
struct DPNPFuncData:
337337
DPNPFuncType return_type
338338
void * ptr
339+
DPNPFuncType return_type_no_fp64
340+
void *ptr_no_fp64
339341

340342
DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +
341343

0 commit comments

Comments
 (0)