Skip to content

Commit a18c1db

Browse files
committed
fix_cov_for_no_fp64
1 parent 2dfa804 commit a18c1db

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ DPCTLSyclEventRef dpnp_cov_c(DPCTLSyclQueueRef q_ref,
192192
nrows, // std::int64_t n,
193193
ncols, // std::int64_t k,
194194
alpha, // T alpha,
195-
temp, //const T* a,
195+
temp, // const T* a,
196196
ncols, // std::int64_t lda,
197197
beta, // T beta,
198198
result, // T* c,
@@ -1379,12 +1379,12 @@ void func_map_init_statistics(func_map_t& fmap)
13791379

13801380
fmap[DPNPFuncName::DPNP_FN_COV][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_cov_default_c<double>};
13811381
fmap[DPNPFuncName::DPNP_FN_COV][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_cov_default_c<double>};
1382-
fmap[DPNPFuncName::DPNP_FN_COV][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_cov_default_c<double>};
1382+
fmap[DPNPFuncName::DPNP_FN_COV][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cov_default_c<float>};
13831383
fmap[DPNPFuncName::DPNP_FN_COV][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cov_default_c<double>};
13841384

13851385
fmap[DPNPFuncName::DPNP_FN_COV_EXT][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_cov_ext_c<double>};
13861386
fmap[DPNPFuncName::DPNP_FN_COV_EXT][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_cov_ext_c<double>};
1387-
fmap[DPNPFuncName::DPNP_FN_COV_EXT][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_cov_ext_c<double>};
1387+
fmap[DPNPFuncName::DPNP_FN_COV_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cov_ext_c<float>};
13881388
fmap[DPNPFuncName::DPNP_FN_COV_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cov_ext_c<double>};
13891389

13901390
fmap[DPNPFuncName::DPNP_FN_MAX][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_max_default_c<int32_t>};

dpnp/dpnp_iface_statistics.py

+25-26
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
import numpy
44-
44+
import dpctl.tensor as dpt
4545
from dpnp.dpnp_algo import *
4646
from dpnp.dpnp_utils import *
4747
from dpnp.dpnp_array import dpnp_array
@@ -274,31 +274,30 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
274274
[1.0, -1.0, -1.0, 1.0]
275275
276276
"""
277-
278-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
279-
if x1_desc:
280-
if x1_desc.ndim > 2:
281-
pass
282-
elif y is not None:
283-
pass
284-
elif bias:
285-
pass
286-
elif ddof is not None:
287-
pass
288-
elif fweights is not None:
289-
pass
290-
elif aweights is not None:
291-
pass
292-
else:
293-
if not rowvar and x1.shape[0] != 1:
294-
x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1
295-
x1 = dpnp_array._create_from_usm_ndarray(x1.mT)
296-
x1 = dpnp.astype(x1, dpnp.float64) if x1_desc.dtype != dpnp.float64 else x1
297-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
298-
elif x1_desc.dtype != dpnp.float64:
299-
x1 = dpnp.astype(x1, dpnp.float64)
300-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
301-
277+
if not isinstance(x1, (dpnp_array, dpt.usm_ndarray)):
278+
pass
279+
elif x1.ndim > 2:
280+
pass
281+
elif y is not None:
282+
pass
283+
elif bias:
284+
pass
285+
elif ddof is not None:
286+
pass
287+
elif fweights is not None:
288+
pass
289+
elif aweights is not None:
290+
pass
291+
else:
292+
if not rowvar and x1.shape[0] != 1:
293+
x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1
294+
x1 = dpnp_array._create_from_usm_ndarray(x1.mT)
295+
296+
if not x1.dtype in (dpnp.float32, dpnp.float64):
297+
x1 = dpnp.astype(x1, dpnp.default_float_type(sycl_queue=x1.sycl_queue))
298+
299+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
300+
if x1_desc:
302301
return dpnp_cov(x1_desc).get_pyobj()
303302

304303
return call_origin(numpy.cov, x1, y, rowvar, bias, ddof, fweights, aweights)

0 commit comments

Comments
 (0)