|
41 | 41 |
|
42 | 42 |
|
43 | 43 | import numpy
|
44 |
| - |
| 44 | +import dpctl.tensor as dpt |
45 | 45 | from dpnp.dpnp_algo import *
|
46 | 46 | from dpnp.dpnp_utils import *
|
47 | 47 | from dpnp.dpnp_array import dpnp_array
|
@@ -274,31 +274,30 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
|
274 | 274 | [1.0, -1.0, -1.0, 1.0]
|
275 | 275 |
|
276 | 276 | """
|
277 |
| - |
278 |
| - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
279 |
| - if x1_desc: |
280 |
| - if x1_desc.ndim > 2: |
281 |
| - pass |
282 |
| - elif y is not None: |
283 |
| - pass |
284 |
| - elif bias: |
285 |
| - pass |
286 |
| - elif ddof is not None: |
287 |
| - pass |
288 |
| - elif fweights is not None: |
289 |
| - pass |
290 |
| - elif aweights is not None: |
291 |
| - pass |
292 |
| - else: |
293 |
| - if not rowvar and x1.shape[0] != 1: |
294 |
| - x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1 |
295 |
| - x1 = dpnp_array._create_from_usm_ndarray(x1.mT) |
296 |
| - x1 = dpnp.astype(x1, dpnp.float64) if x1_desc.dtype != dpnp.float64 else x1 |
297 |
| - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
298 |
| - elif x1_desc.dtype != dpnp.float64: |
299 |
| - x1 = dpnp.astype(x1, dpnp.float64) |
300 |
| - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
301 |
| - |
| 277 | + if not isinstance(x1, (dpnp_array, dpt.usm_ndarray)): |
| 278 | + pass |
| 279 | + elif x1.ndim > 2: |
| 280 | + pass |
| 281 | + elif y is not None: |
| 282 | + pass |
| 283 | + elif bias: |
| 284 | + pass |
| 285 | + elif ddof is not None: |
| 286 | + pass |
| 287 | + elif fweights is not None: |
| 288 | + pass |
| 289 | + elif aweights is not None: |
| 290 | + pass |
| 291 | + else: |
| 292 | + if not rowvar and x1.shape[0] != 1: |
| 293 | + x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1 |
| 294 | + x1 = dpnp_array._create_from_usm_ndarray(x1.mT) |
| 295 | + |
| 296 | + if not x1.dtype in (dpnp.float32, dpnp.float64): |
| 297 | + x1 = dpnp.astype(x1, dpnp.default_float_type(sycl_queue=x1.sycl_queue)) |
| 298 | + |
| 299 | + x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
| 300 | + if x1_desc: |
302 | 301 | return dpnp_cov(x1_desc).get_pyobj()
|
303 | 302 |
|
304 | 303 | return call_origin(numpy.cov, x1, y, rowvar, bias, ddof, fweights, aweights)
|
|
0 commit comments