Skip to content

Commit b67bcea

Browse files
committed
Adds more tests for statistical functions
1 parent d81a6a7 commit b67bcea

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

dpctl/tests/test_tensor_statistical_functions.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,185 @@ def test_std_var_dtypes(dt, py_zero):
7070
assert res.dtype == dpt.dtype(default_device_fp_type(q))
7171
else:
7272
assert res.dtype == x.dtype
73+
74+
75+
def test_stat_fns_axis():
76+
get_queue_or_skip()
77+
78+
x = dpt.ones((3, 4, 5, 6, 7), dtype="f4")
79+
m = dpt.mean(x, axis=(1, 2, -1))
80+
81+
assert isinstance(m, dpt.usm_ndarray)
82+
assert m.shape == (3, 6)
83+
assert dpt.allclose(m, dpt.asarray(1, dtype=m.dtype))
84+
85+
s = dpt.var(x, axis=(1, 2, -1))
86+
assert isinstance(s, dpt.usm_ndarray)
87+
assert s.shape == (3, 6)
88+
assert dpt.allclose(s, dpt.asarray(0, dtype=s.dtype))
89+
90+
91+
@pytest.mark.parametrize("fn", [dpt.mean, dpt.var])
92+
def test_stat_fns_empty(fn):
93+
get_queue_or_skip()
94+
x = dpt.empty((0,), dtype="f4")
95+
r = fn(x)
96+
assert r.shape == tuple()
97+
assert dpt.isnan(r)
98+
99+
x = dpt.empty((10, 0, 2), dtype="f4")
100+
r = fn(x, axis=1)
101+
assert r.shape == (10, 2)
102+
assert dpt.all(dpt.isnan(r))
103+
104+
r = fn(x, axis=0)
105+
assert r.shape == (0, 2)
106+
assert r.size == 0
107+
108+
109+
def test_stat_fns_keepdims():
110+
get_queue_or_skip()
111+
112+
x = dpt.ones((3, 4, 5, 6, 7), dtype="f4")
113+
m = dpt.mean(x, axis=(1, 2, -1), keepdims=True)
114+
115+
assert isinstance(m, dpt.usm_ndarray)
116+
assert m.shape == (3, 1, 1, 6, 1)
117+
assert dpt.allclose(m, dpt.asarray(1, dtype=m.dtype))
118+
119+
s = dpt.var(x, axis=(1, 2, -1), keepdims=True)
120+
assert isinstance(s, dpt.usm_ndarray)
121+
assert s.shape == (3, 1, 1, 6, 1)
122+
assert dpt.allclose(s, dpt.asarray(0, dtype=s.dtype))
123+
124+
125+
def test_stat_fns_empty_axis():
126+
get_queue_or_skip()
127+
128+
x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))
129+
m = dpt.mean(x, axis=())
130+
131+
assert x.shape == m.shape
132+
assert dpt.all(x == m)
133+
134+
s = dpt.var(x, axis=())
135+
assert x.shape == s.shape
136+
assert dpt.all(s == 0)
137+
138+
d = dpt.std(x, axis=())
139+
assert x.shape == d.shape
140+
assert dpt.all(d == 0)
141+
142+
143+
def test_mean():
144+
get_queue_or_skip()
145+
146+
x = dpt.reshape(dpt.arange(9, dtype="f4"), (3, 3))
147+
m = dpt.mean(x)
148+
expected = dpt.asarray(4, dtype="f4")
149+
assert dpt.allclose(m, expected)
150+
151+
m = dpt.mean(x, axis=0)
152+
expected = dpt.arange(3, 6, dtype="f4")
153+
assert dpt.allclose(m, expected)
154+
155+
m = dpt.mean(x, axis=1)
156+
expected = dpt.asarray([1, 4, 7], dtype="f4")
157+
assert dpt.allclose(m, expected)
158+
159+
160+
def test_var_std():
161+
get_queue_or_skip()
162+
163+
x = dpt.reshape(dpt.arange(9, dtype="f4"), (3, 3))
164+
r = dpt.var(x)
165+
expected = dpt.asarray(6.666666507720947, dtype="f4")
166+
assert dpt.allclose(r, expected)
167+
168+
r1 = dpt.var(x, correction=3)
169+
expected1 = dpt.asarray(10.0, dtype="f4")
170+
assert dpt.allclose(r1, expected1)
171+
172+
r = dpt.std(x)
173+
expected = dpt.sqrt(expected)
174+
assert dpt.allclose(r, expected)
175+
176+
r1 = dpt.std(x, correction=3)
177+
expected1 = dpt.sqrt(expected1)
178+
assert dpt.allclose(r1, expected1)
179+
180+
r = dpt.var(x, axis=0)
181+
expected = dpt.full(x.shape[1], 6, dtype="f4")
182+
assert dpt.allclose(r, expected)
183+
184+
r1 = dpt.var(x, axis=0, correction=1)
185+
expected1 = dpt.full(x.shape[1], 9, dtype="f4")
186+
assert dpt.allclose(r1, expected1)
187+
188+
r = dpt.std(x, axis=0)
189+
expected = dpt.sqrt(expected)
190+
assert dpt.allclose(r, expected)
191+
192+
r1 = dpt.std(x, axis=0, correction=1)
193+
expected1 = dpt.sqrt(expected1)
194+
assert dpt.allclose(r1, expected1)
195+
196+
r = dpt.var(x, axis=1)
197+
expected = dpt.full(x.shape[0], 0.6666666865348816, dtype="f4")
198+
assert dpt.allclose(r, expected)
199+
200+
r1 = dpt.var(x, axis=1, correction=1)
201+
expected1 = dpt.ones(x.shape[0], dtype="f4")
202+
assert dpt.allclose(r1, expected1)
203+
204+
r = dpt.std(x, axis=1)
205+
expected = dpt.sqrt(expected)
206+
assert dpt.allclose(r, expected)
207+
208+
r1 = dpt.std(x, axis=1, correction=1)
209+
expected1 = dpt.sqrt(expected1)
210+
assert dpt.allclose(r1, expected1)
211+
212+
213+
def test_var_axis_length_correction():
214+
get_queue_or_skip()
215+
216+
x = dpt.reshape(dpt.arange(9, dtype="f4"), (3, 3))
217+
218+
r = dpt.var(x, correction=x.size)
219+
assert dpt.isnan(r)
220+
221+
r = dpt.var(x, axis=0, correction=x.shape[0])
222+
assert dpt.all(dpt.isnan(r))
223+
224+
r = dpt.var(x, axis=1, correction=x.shape[1])
225+
assert dpt.all(dpt.isnan(r))
226+
227+
228+
def test_stat_function_errors():
229+
d = dict()
230+
with pytest.raises(TypeError):
231+
dpt.var(d)
232+
with pytest.raises(TypeError):
233+
dpt.std(d)
234+
with pytest.raises(TypeError):
235+
dpt.mean(d)
236+
237+
x = dpt.empty(1, dtype="f4")
238+
with pytest.raises(TypeError):
239+
dpt.var(x, axis=d)
240+
with pytest.raises(TypeError):
241+
dpt.std(x, axis=d)
242+
with pytest.raises(TypeError):
243+
dpt.mean(x, axis=d)
244+
245+
with pytest.raises(TypeError):
246+
dpt.var(x, correction=d)
247+
with pytest.raises(TypeError):
248+
dpt.std(x, correction=d)
249+
250+
x = dpt.empty(1, dtype="c8")
251+
with pytest.raises(ValueError):
252+
dpt.var(x)
253+
with pytest.raises(ValueError):
254+
dpt.std(x)

0 commit comments

Comments
 (0)