Skip to content

Commit 5b7ba6d

Browse files
committed
Implement dpnp.cov() though existing dpnp methods
1 parent 4d27b4c commit 5b7ba6d

File tree

5 files changed

+250
-63
lines changed

5 files changed

+250
-63
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
import dpctl.tensor as dpt
4545
from dpnp.dpnp_algo import *
4646
from dpnp.dpnp_utils import *
47+
from dpnp.dpnp_utils.dpnp_utils_statistics import (
48+
dpnp_cov
49+
)
4750
from dpnp.dpnp_array import dpnp_array
4851
import dpnp
4952

@@ -237,13 +240,18 @@ def correlate(x1, x2, mode='valid'):
237240
return call_origin(numpy.correlate, x1, x2, mode=mode)
238241

239242

240-
def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None):
241-
"""cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None):
243+
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, *, dtype=None):
244+
"""cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, *, dtype=None):
242245
243246
Estimate a covariance matrix, given data and weights.
244247
245248
For full documentation refer to :obj:`numpy.cov`.
246249
250+
Returns
251+
-------
252+
out : dpnp.ndarray
253+
The covariance matrix of the variables.
254+
247255
Limitations
248256
-----------
249257
Input array ``m`` is supported as :obj:`dpnp.ndarray`.
@@ -257,7 +265,9 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
257265
Otherwise the function will be executed sequentially on CPU.
258266
Input array data types are limited by supported DPNP :ref:`Data types`.
259267
260-
.. see also:: :obj:`dpnp.corrcoef` normalized covariance matrix.
268+
See Also
269+
--------
270+
:obj:`dpnp.corrcoef` : Normalized covariance matrix
261271
262272
Examples
263273
--------
@@ -274,11 +284,10 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
274284
[1.0, -1.0, -1.0, 1.0]
275285
276286
"""
277-
if not isinstance(x1, (dpnp_array, dpt.usm_ndarray)):
278-
pass
279-
elif x1.ndim > 2:
287+
288+
if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
280289
pass
281-
elif y is not None:
290+
elif m.ndim > 2:
282291
pass
283292
elif bias:
284293
pass
@@ -289,18 +298,13 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=
289298
elif aweights is not None:
290299
pass
291300
else:
292-
if not rowvar and x1.shape[0] != 1:
293-
x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1
294-
x1 = dpnp_array._create_from_usm_ndarray(x1.mT)
295-
296-
if not x1.dtype in (dpnp.float32, dpnp.float64):
297-
x1 = dpnp.astype(x1, dpnp.default_float_type(sycl_queue=x1.sycl_queue))
301+
return dpnp_cov(m, y=y, rowvar=rowvar, dtype=dtype)
298302

299-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
300-
if x1_desc:
301-
return dpnp_cov(x1_desc).get_pyobj()
303+
# x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
304+
# if x1_desc:
305+
# return dpnp_cov(x1_desc).get_pyobj()
302306

303-
return call_origin(numpy.cov, x1, y, rowvar, bias, ddof, fweights, aweights)
307+
return call_origin(numpy.cov, m, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype)
304308

305309

306310
def histogram(a, bins=10, range=None, density=None, weights=None):
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# cython: language_level=3
2+
# distutils: language = c++
3+
# -*- coding: utf-8 -*-
4+
# *****************************************************************************
5+
# Copyright (c) 2023, Intel Corporation
6+
# All rights reserved.
7+
#
8+
# Redistribution and use in source and binary forms, with or without
9+
# modification, are permitted provided that the following conditions are met:
10+
# - Redistributions of source code must retain the above copyright notice,
11+
# this list of conditions and the following disclaimer.
12+
# - Redistributions in binary form must reproduce the above copyright notice,
13+
# this list of conditions and the following disclaimer in the documentation
14+
# and/or other materials provided with the distribution.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
30+
import dpnp
31+
from dpnp.dpnp_array import dpnp_array
32+
from dpnp.dpnp_utils import (
33+
get_usm_allocations
34+
)
35+
36+
import dpctl
37+
import dpctl.tensor as dpt
38+
import dpctl.tensor._tensor_impl as ti
39+
40+
# TODO: replace with calls from dpctl module
41+
import unary_fns
42+
43+
44+
__all__ = [
45+
"dpnp_cov"
46+
]
47+
48+
def dpnp_cov(m, y=None, rowvar=True, dtype=None):
49+
"""
50+
Estimate a covariance matrix based on passed data.
51+
No support for given wights is provided now.
52+
53+
The implementation is done though existing dpnp and dpctl methods
54+
instead of separate function call of dnpn backend.
55+
56+
"""
57+
58+
def _get_2dmin_array(x, dtype):
59+
"""
60+
Transfor an input array to a form required for building a covariance matrix.
61+
62+
If applicable, it resahpes the imput array to have 2 dimensions or greater.
63+
If applicable, it transposes the imput array when 'rowvar' is False.
64+
It casts to another dtype, if the input array differs from requested one.
65+
66+
"""
67+
68+
if x.ndim == 0:
69+
x = x.reshape((1, 1))
70+
elif m.ndim == 1:
71+
x = x[dpnp.newaxis, :]
72+
73+
if not rowvar and x.shape[0] != 1:
74+
# TODO: replace once ready with
75+
# x = x.T
76+
x = dpnp_array._create_from_usm_ndarray(x.get_array().T)
77+
78+
if x.dtype != dtype:
79+
x = dpnp.astype(x, dtype)
80+
return x
81+
82+
83+
# input arrays must follow CFD paradigm
84+
usm_type, queue = get_usm_allocations((m, ) if y is None else (m, y))
85+
86+
# calculate a type of result array if not passed explicitly
87+
if dtype is None:
88+
dtypes = [m.dtype, dpnp.default_float_type(sycl_queue=queue)]
89+
if y is not None:
90+
dtypes.append(y.dtype)
91+
dtype = dpt.result_type(*dtypes)
92+
93+
X = _get_2dmin_array(m, dtype)
94+
if y is not None:
95+
y = _get_2dmin_array(y, dtype)
96+
97+
# TODO: replace with dpnp.concatenate((X, y), axis=0) once dpctl implementation is ready
98+
if X.ndim != y.ndim:
99+
raise ValueError("all the input arrays must have same number of dimensions")
100+
101+
if X.shape[1:] != y.shape[1:]:
102+
raise ValueError("all the input array dimensions for the concatenation axis must match exactly")
103+
104+
res_shape = tuple(X.shape[i] if i > 0 else (X.shape[i] + y.shape[i]) for i in range(X.ndim))
105+
res_usm = dpt.empty(res_shape, dtype=dtype, usm_type=usm_type, sycl_queue=queue)
106+
107+
# concatenate input arrays 'm' and 'y' into single array among 0-axis
108+
hev1, _ = ti._copy_usm_ndarray_into_usm_ndarray(src=X.get_array(), dst=res_usm[:X.shape[0]], sycl_queue=queue)
109+
hev2, _ = ti._copy_usm_ndarray_into_usm_ndarray(src=y.get_array(), dst=res_usm[X.shape[0]:], sycl_queue=queue)
110+
dpctl.SyclEvent.wait_for([hev1, hev2])
111+
112+
X = dpnp_array._create_from_usm_ndarray(res_usm)
113+
114+
# TODO: replace once ready with
115+
# avg = X.mean(axis=1)
116+
# avg = X.sum(axis=1) / X.shape[1]
117+
avg = unary_fns.sum(X.get_array(), axis=1) / X.shape[1]
118+
119+
fact = X.shape[1] - 1
120+
X -= avg[:, None]
121+
122+
# TODO: replace once ready with
123+
# c = dpnp.dot(X, X.T.conj())
124+
c = dpnp.dot(X, dpnp_array._create_from_usm_ndarray(X.get_array().T).conj())
125+
c *= 1 / fact if fact != 0 else dpnp.nan
126+
127+
# TODO: replace with dpnp.squeeze(c) once ready
128+
usm_c = dpnp.get_usm_ndarray(c)
129+
usm_c = dpt.squeeze(usm_c)
130+
return dpnp_array._create_from_usm_ndarray(usm_c)

tests/skipped_tests_gpu.tbl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{extern
271271
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_1_{external=False, length=20000}::test_partition_invalid_negative_kth
272272
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_axis
273273
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{external=True, length=10}::test_partition_negative_axis
274-
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCov::test_cov_empty
274+
275275
tests/third_party/cupy/statistics_tests/test_meanvar.py::TestMeanVar::test_external_mean_axis
276276

277277
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_axis

tests/third_party/cupy/statistics_tests/test_correlation.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import sys
12
import unittest
23

34
import numpy
45
import pytest
56

67
import dpnp as cupy
8+
from dpctl import select_default_device
79
from tests.third_party.cupy import testing
810

911

@@ -37,9 +39,11 @@ def test_corrcoef_rowvar(self, xp, dtype):
3739
return xp.corrcoef(a, y=y, rowvar=False)
3840

3941

40-
@testing.gpu
4142
class TestCov(unittest.TestCase):
4243

44+
# resulting dtype will differ with numpy if no fp64 support by a default device
45+
_has_fp64 = select_default_device().has_aspect_fp64
46+
4347
def generate_input(self, a_shape, y_shape, xp, dtype):
4448
a = testing.shaped_arange(a_shape, xp, dtype)
4549
y = None
@@ -48,27 +52,40 @@ def generate_input(self, a_shape, y_shape, xp, dtype):
4852
return a, y
4953

5054
@testing.for_all_dtypes()
51-
@testing.numpy_cupy_allclose(type_check=False)
55+
@testing.numpy_cupy_allclose(type_check=_has_fp64, accept_error=True)
5256
def check(self, a_shape, y_shape=None, rowvar=True, bias=False,
53-
ddof=None, xp=None, dtype=None):
57+
ddof=None, xp=None, dtype=None,
58+
fweights=None, aweights=None, name=None):
5459
a, y = self.generate_input(a_shape, y_shape, xp, dtype)
55-
return xp.cov(a, y, rowvar, bias, ddof)
60+
if fweights is not None:
61+
fweights = name.asarray(fweights)
62+
if aweights is not None:
63+
aweights = name.asarray(aweights)
64+
# print(type(fweights))
65+
# return xp.cov(a, y, rowvar, bias, ddof,
66+
# fweights, aweights, dtype=dtype)
67+
return xp.cov(a, y, rowvar, bias, ddof,
68+
fweights, aweights)
5669

5770
@testing.for_all_dtypes()
58-
@testing.numpy_cupy_allclose()
71+
@testing.numpy_cupy_allclose(accept_error=True)
5972
def check_warns(self, a_shape, y_shape=None, rowvar=True, bias=False,
60-
ddof=None, xp=None, dtype=None):
73+
ddof=None, xp=None, dtype=None,
74+
fweights=None, aweights=None):
6175
with testing.assert_warns(RuntimeWarning):
6276
a, y = self.generate_input(a_shape, y_shape, xp, dtype)
63-
return xp.cov(a, y, rowvar, bias, ddof)
77+
return xp.cov(a, y, rowvar, bias, ddof,
78+
fweights, aweights, dtype=dtype)
6479

6580
@testing.for_all_dtypes()
66-
def check_raises(self, a_shape, y_shape=None, rowvar=True, bias=False,
67-
ddof=None, dtype=None):
81+
def check_raises(self, a_shape, y_shape=None,
82+
rowvar=True, bias=False, ddof=None,
83+
dtype=None, fweights=None, aweights=None):
6884
for xp in (numpy, cupy):
6985
a, y = self.generate_input(a_shape, y_shape, xp, dtype)
7086
with pytest.raises(ValueError):
71-
xp.cov(a, y, rowvar, bias, ddof)
87+
xp.cov(a, y, rowvar, bias, ddof,
88+
fweights, aweights, dtype=dtype)
7289

7390
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
7491
def test_cov(self):
@@ -78,6 +95,12 @@ def test_cov(self):
7895
self.check((2, 3), (2, 3), rowvar=False)
7996
self.check((2, 3), bias=True)
8097
self.check((2, 3), ddof=2)
98+
self.check((2, 3))
99+
self.check((1, 3), fweights=(1, 4, 1))
100+
self.check((1, 3), aweights=(1.0, 4.0, 1.0))
101+
self.check((1, 3), bias=True, aweights=(1.0, 4.0, 1.0))
102+
self.check((1, 3), fweights=(1, 4, 1),
103+
aweights=(1.0, 4.0, 1.0))
81104

82105
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
83106
def test_cov_warns(self):

0 commit comments

Comments
 (0)