
Commit d81a6a7

Implements mean, var, and std
1 parent b1fea28 commit d81a6a7

3 files changed: +357 additions, 0 deletions


dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -90,6 +90,7 @@
 )
 from dpctl.tensor._reshape import reshape
 from dpctl.tensor._search_functions import where
+from dpctl.tensor._statistical_functions import mean, std, var
 from dpctl.tensor._usmarray import usm_ndarray
 from dpctl.tensor._utility_functions import all, any

@@ -335,4 +336,7 @@
     "clip",
     "logsumexp",
     "reduce_hypot",
+    "mean",
+    "std",
+    "var",
 ]
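
With these entries added to __all__, the new reductions are reachable directly from the dpctl.tensor namespace. A minimal usage sketch, assuming a default SYCL device is available; the commented values follow from the definitions, with correction=0 giving the population variance:

import dpctl.tensor as dpt

x = dpt.asarray([1.0, 2.0, 3.0, 4.0])
dpt.mean(x)               # 2.5
dpt.var(x)                # 1.25 (population variance, correction=0)
dpt.std(x, correction=1)  # ~1.291 (sample standard deviation)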
dpctl/tensor/_statistical_functions.py

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy.core.numeric import normalize_axis_tuple

import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_elementwise_impl as tei
import dpctl.tensor._tensor_impl as ti
import dpctl.tensor._tensor_reductions_impl as tri

from ._reduction import _default_reduction_dtype


def _var_impl(x, axis, correction, keepdims):
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, nd, "axis")
    perm = []
    nelems = 1
    for i in range(nd):
        if i not in axis:
            perm.append(i)
        else:
            nelems *= x.shape[i]
    red_nd = len(axis)
    perm = perm + list(axis)
    q = x.sycl_queue
    inp_dt = x.dtype
    res_dt = (
        inp_dt
        if inp_dt.kind == "f"
        else dpt.dtype(ti.default_device_fp_type(q))
    )
    res_usm_type = x.usm_type

    deps = []
    host_tasks_list = []
    if inp_dt != res_dt:
        buf = dpt.empty_like(x, dtype=res_dt)
        ht_e_buf, c_e1 = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=q
        )
        deps.append(c_e1)
        host_tasks_list.append(ht_e_buf)
    else:
        buf = x
    # calculate mean
    buf2 = dpt.permute_dims(buf, perm)
    res_shape = buf2.shape[: nd - red_nd]
    # use keepdims=True path for later broadcasting
    if red_nd == 0:
        mean_ary = dpt.empty_like(buf)
        ht_e1, c_e2 = ti._copy_usm_ndarray_into_usm_ndarray(
            src=buf, dst=mean_ary, sycl_queue=q
        )
        deps.append(c_e2)
        host_tasks_list.append(ht_e1)
    else:
        mean_ary = dpt.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=q,
        )
        ht_e1, r_e1 = tri._sum_over_axis(
            src=buf2,
            trailing_dims_to_reduce=red_nd,
            dst=mean_ary,
            sycl_queue=q,
            depends=deps,
        )
        host_tasks_list.append(ht_e1)
        deps.append(r_e1)

        mean_ary_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        mean_ary = dpt.permute_dims(
            dpt.reshape(mean_ary, mean_ary_shape), inv_perm
        )
    # divide in-place to get mean
    mean_ary_shape = mean_ary.shape
    nelems_ary = dpt.asarray(
        nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
    )
    if nelems_ary.shape != mean_ary_shape:
        nelems_ary = dpt.broadcast_to(nelems_ary, mean_ary_shape)
    ht_e2, d_e1 = tei._divide_inplace(
        lhs=mean_ary, rhs=nelems_ary, sycl_queue=q, depends=deps
    )
    host_tasks_list.append(ht_e2)
    # subtract mean from original array to get deviations
    dev_ary = dpt.empty_like(buf)
    if mean_ary_shape != buf.shape:
        mean_ary = dpt.broadcast_to(mean_ary, buf.shape)
    ht_e4, su_e = tei._subtract(
        src1=buf, src2=mean_ary, dst=dev_ary, sycl_queue=q, depends=[d_e1]
    )
    host_tasks_list.append(ht_e4)
    # square deviations
    ht_e5, sq_e = tei._square(
        src=dev_ary, dst=dev_ary, sycl_queue=q, depends=[su_e]
    )
    host_tasks_list.append(ht_e5)
    deps2 = []
    # take sum of squared deviations
    dev_ary2 = dpt.permute_dims(dev_ary, perm)
    if red_nd == 0:
        res = dev_ary
        deps2.append(sq_e)
    else:
        res = dpt.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=q,
        )
        ht_e6, r_e2 = tri._sum_over_axis(
            src=dev_ary2,
            trailing_dims_to_reduce=red_nd,
            dst=res,
            sycl_queue=q,
            depends=[sq_e],
        )
        host_tasks_list.append(ht_e6)
        deps2.append(r_e2)

    if keepdims:
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
    res_shape = res.shape
    # when nelems - correction <= 0, yield nans
    div = max(nelems - correction, 0)
    if not div:
        div = dpt.nan
    div_ary = dpt.asarray(div, res_dt, usm_type=res_usm_type, sycl_queue=q)
    # divide in-place again
    if div_ary.shape != res_shape:
        div_ary = dpt.broadcast_to(div_ary, res.shape)
    ht_e7, d_e2 = tei._divide_inplace(
        lhs=res, rhs=div_ary, sycl_queue=q, depends=deps2
    )
    host_tasks_list.append(ht_e7)
    return res, [d_e2], host_tasks_list


def mean(x, axis=None, keepdims=False):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    nd = x.ndim
    if axis is None:
        axis = tuple(range(nd))
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, nd, "axis")
    perm = []
    nelems = 1
    for i in range(nd):
        if i not in axis:
            perm.append(i)
        else:
            nelems *= x.shape[i]
    sum_nd = len(axis)
    perm = perm + list(axis)
    arr2 = dpt.permute_dims(x, perm)
    res_shape = arr2.shape[: nd - sum_nd]
    q = x.sycl_queue
    inp_dt = x.dtype
    res_dt = (
        x.dtype
        if x.dtype.kind in "fc"
        else dpt.dtype(ti.default_device_fp_type(q))
    )
    res_usm_type = x.usm_type
    if sum_nd == 0:
        return dpt.astype(x, res_dt, copy=True)

    s_e = []
    host_tasks_list = []
    if tri._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q):
        res = dpt.empty(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e1, r_e = tri._sum_over_axis(
            src=arr2, trailing_dims_to_reduce=sum_nd, dst=res, sycl_queue=q
        )
        host_tasks_list.append(ht_e1)
        s_e.append(r_e)
    else:
        tmp_dt = _default_reduction_dtype(inp_dt, q)
        tmp = dpt.empty(
            res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e_tmp, r_e = tri._sum_over_axis(
            src=arr2, trailing_dims_to_reduce=sum_nd, dst=tmp, sycl_queue=q
        )
        host_tasks_list.append(ht_e_tmp)
        res = dpt.empty(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
        )
        ht_e1, c_e = ti._copy_usm_ndarray_into_usm_ndarray(
            src=tmp, dst=res, sycl_queue=q, depends=[r_e]
        )
        host_tasks_list.append(ht_e1)
        s_e.append(c_e)

    if keepdims:
        res_shape = res_shape + (1,) * sum_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)

    res_shape = res.shape
    # in-place divide
    nelems_arr = dpt.asarray(
        nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
    )
    if nelems_arr.shape != res_shape:
        nelems_arr = dpt.broadcast_to(nelems_arr, res_shape)
    ht_e2, _ = tei._divide_inplace(
        lhs=res, rhs=nelems_arr, sycl_queue=q, depends=s_e
    )
    host_tasks_list.append(ht_e2)
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res


def var(x, axis=None, correction=0.0, keepdims=False):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    if not isinstance(correction, (int, float)):
        raise TypeError(
            "Expected a Python integer or float for `correction`, "
            f"got {type(correction)}"
        )

    if x.dtype.kind == "c":
        raise ValueError("`var` does not support complex types")

    res, _, host_tasks_list = _var_impl(x, axis, correction, keepdims)
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res


def std(x, axis=None, correction=0.0, keepdims=False):
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    if not isinstance(correction, (int, float)):
        raise TypeError(
            "Expected a Python integer or float for `correction`, "
            f"got {type(correction)}"
        )

    if x.dtype.kind == "c":
        raise ValueError("`std` does not support complex types")

    res, deps, host_tasks_list = _var_impl(x, axis, correction, keepdims)
    ht_ev, _ = tei._sqrt(
        src=res, dst=res, sycl_queue=res.sycl_queue, depends=deps
    )
    host_tasks_list.append(ht_ev)
    dpctl.SyclEvent.wait_for(host_tasks_list)
    return res
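
For orientation, _var_impl chains the device kernels as: sum over the reduced axes, divide in place to form the mean, subtract the mean, square, sum the squared deviations, and divide by max(nelems - correction, 0), mapping a non-positive divisor to NaN. A rough host-side NumPy sketch of the same arithmetic, for illustration only (var_reference is a hypothetical name, not part of this commit; the real implementation stays on the device and chains SYCL events between steps):

import numpy as np
from numpy.core.numeric import normalize_axis_tuple


def var_reference(x, axis=None, correction=0.0, keepdims=False):
    # host-side mirror of _var_impl's two-pass arithmetic
    axes = tuple(range(x.ndim)) if axis is None else normalize_axis_tuple(axis, x.ndim)
    nelems = np.prod([x.shape[i] for i in axes])
    m = x.sum(axis=axes, keepdims=True) / nelems    # mean over the reduced axes
    sq_dev = (x - m) ** 2                           # squared deviations
    div = max(nelems - correction, 0) or np.nan     # non-positive divisor yields NaN
    return sq_dev.sum(axis=axes, keepdims=keepdims) / div

std() then takes the element-wise square root of this quantity, which is what the call to tei._sqrt on the result of _var_impl does.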
New test file

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import dpctl.tensor as dpt
from dpctl.tensor._tensor_impl import default_device_fp_type
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

_no_complex_dtypes = [
    "?",
    "i1",
    "u1",
    "i2",
    "u2",
    "i4",
    "u4",
    "i8",
    "u8",
    "f2",
    "f4",
    "f8",
]


@pytest.mark.parametrize("dt", _no_complex_dtypes)
def test_mean_dtypes(dt):
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dt, q)

    x = dpt.ones(10, dtype=dt)
    res = dpt.mean(x)
    assert res == 1
    if x.dtype.kind in "biu":
        assert res.dtype == dpt.dtype(default_device_fp_type(q))
    else:
        assert res.dtype == x.dtype


@pytest.mark.parametrize("dt", _no_complex_dtypes)
@pytest.mark.parametrize("py_zero", [float(0), int(0)])
def test_std_var_dtypes(dt, py_zero):
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dt, q)

    x = dpt.ones(10, dtype=dt)
    res = dpt.std(x, correction=py_zero)
    assert res == 0
    if x.dtype.kind in "biu":
        assert res.dtype == dpt.dtype(default_device_fp_type(q))
    else:
        assert res.dtype == x.dtype

    res = dpt.var(x, correction=py_zero)
    assert res == 0
    if x.dtype.kind in "biu":
        assert res.dtype == dpt.dtype(default_device_fp_type(q))
    else:
        assert res.dtype == x.dtype
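
The new tests only exercise dtype promotion on 1-d arrays of ones. A hedged sketch of the kind of axis/keepdims check that could sit alongside them, not part of this commit and reusing only helpers already imported above:

import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip


def test_mean_axis_keepdims_sketch():
    # hypothetical follow-up: reduce along axis=1 and keep the reduced dimension
    get_queue_or_skip()
    x = dpt.reshape(dpt.arange(12, dtype="f4"), (3, 4))
    m = dpt.mean(x, axis=1, keepdims=True)
    assert m.shape == (3, 1)
    # row means of [0..3], [4..7], [8..11] are 1.5, 5.5, 9.5
    assert dpt.all(m == dpt.asarray([[1.5], [5.5], [9.5]], dtype="f4"))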
