Skip to content

Commit 2dfa804

Browse files
vtavanaantonwolfy
authored andcommitted
using rowvar flag in dpnp.cov
1 parent 12e791e commit 2dfa804

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

dpnp/dpnp_iface_statistics.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from dpnp.dpnp_algo import *
4646
from dpnp.dpnp_utils import *
47+
from dpnp.dpnp_array import dpnp_array
4748
import dpnp
4849

4950

@@ -237,7 +238,8 @@ def correlate(x1, x2, mode='valid'):
237238

238239

239240
def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None):
240-
"""
241+
"""cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None):
242+
241243
Estimate a covariance matrix, given data and weights.
242244
243245
For full documentation refer to :obj:`numpy.cov`.
@@ -248,7 +250,6 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
248250
Dimension of input array ``m`` is limited by ``m.ndim > 2``.
249251
Size and shape of input arrays are supported to be equal.
250252
Prameters ``y`` is supported only with default value ``None``.
251-
Prameters ``rowvar`` is supported only with default value ``True``.
252253
Prameters ``bias`` is supported only with default value ``False``.
253254
Prameters ``ddof`` is supported only with default value ``None``.
254255
Prameters ``fweights`` is supported only with default value ``None``.
@@ -280,8 +281,6 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
280281
pass
281282
elif y is not None:
282283
pass
283-
elif not rowvar:
284-
pass
285284
elif bias:
286285
pass
287286
elif ddof is not None:
@@ -291,8 +290,14 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
291290
elif aweights is not None:
292291
pass
293292
else:
294-
if x1_desc.dtype != dpnp.float64:
295-
x1_desc = dpnp.get_dpnp_descriptor(dpnp.astype(x1, dpnp.float64), copy_when_nondefault_queue=False)
293+
if not rowvar and x1.shape[0] != 1:
294+
x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1
295+
x1 = dpnp_array._create_from_usm_ndarray(x1.mT)
296+
x1 = dpnp.astype(x1, dpnp.float64) if x1_desc.dtype != dpnp.float64 else x1
297+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
298+
elif x1_desc.dtype != dpnp.float64:
299+
x1 = dpnp.astype(x1, dpnp.float64)
300+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
296301

297302
return dpnp_cov(x1_desc).get_pyobj()
298303

tests/test_statistics.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
2+
from .helper import get_all_dtypes
33
import dpnp
44

55
import numpy
@@ -114,3 +114,18 @@ def test_bincount_weights(self, array, weights):
114114
expected = numpy.bincount(np_a, weights=weights)
115115
result = dpnp.bincount(dpnp_a, weights=weights)
116116
numpy.testing.assert_array_equal(expected, result)
117+
118+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True))
119+
def test_cov_rowvar1(dtype):
120+
a = dpnp.array([[0, 2], [1, 1], [2, 0]], dtype=dtype)
121+
b = numpy.array([[0, 2], [1, 1], [2, 0]], dtype=dtype)
122+
numpy.testing.assert_array_equal(dpnp.cov(a.T), dpnp.cov(a,rowvar=False))
123+
numpy.testing.assert_array_equal(numpy.cov(b,rowvar=False), dpnp.cov(a,rowvar=False))
124+
125+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True))
126+
def test_cov_rowvar2(dtype):
127+
a = dpnp.array([[0, 1, 2]], dtype=dtype)
128+
b = numpy.array([[0, 1, 2]], dtype=dtype)
129+
numpy.testing.assert_array_equal(numpy.cov(b,rowvar=False), dpnp.cov(a,rowvar=False))
130+
131+

0 commit comments

Comments
 (0)