@@ -70,3 +70,185 @@ def test_std_var_dtypes(dt, py_zero):
70
70
assert res .dtype == dpt .dtype (default_device_fp_type (q ))
71
71
else :
72
72
assert res .dtype == x .dtype
73
+
74
+
75
+ def test_stat_fns_axis ():
76
+ get_queue_or_skip ()
77
+
78
+ x = dpt .ones ((3 , 4 , 5 , 6 , 7 ), dtype = "f4" )
79
+ m = dpt .mean (x , axis = (1 , 2 , - 1 ))
80
+
81
+ assert isinstance (m , dpt .usm_ndarray )
82
+ assert m .shape == (3 , 6 )
83
+ assert dpt .allclose (m , dpt .asarray (1 , dtype = m .dtype ))
84
+
85
+ s = dpt .var (x , axis = (1 , 2 , - 1 ))
86
+ assert isinstance (s , dpt .usm_ndarray )
87
+ assert s .shape == (3 , 6 )
88
+ assert dpt .allclose (s , dpt .asarray (0 , dtype = s .dtype ))
89
+
90
+
91
+ @pytest .mark .parametrize ("fn" , [dpt .mean , dpt .var ])
92
+ def test_stat_fns_empty (fn ):
93
+ get_queue_or_skip ()
94
+ x = dpt .empty ((0 ,), dtype = "f4" )
95
+ r = fn (x )
96
+ assert r .shape == tuple ()
97
+ assert dpt .isnan (r )
98
+
99
+ x = dpt .empty ((10 , 0 , 2 ), dtype = "f4" )
100
+ r = fn (x , axis = 1 )
101
+ assert r .shape == (10 , 2 )
102
+ assert dpt .all (dpt .isnan (r ))
103
+
104
+ r = fn (x , axis = 0 )
105
+ assert r .shape == (0 , 2 )
106
+ assert r .size == 0
107
+
108
+
109
+ def test_stat_fns_keepdims ():
110
+ get_queue_or_skip ()
111
+
112
+ x = dpt .ones ((3 , 4 , 5 , 6 , 7 ), dtype = "f4" )
113
+ m = dpt .mean (x , axis = (1 , 2 , - 1 ), keepdims = True )
114
+
115
+ assert isinstance (m , dpt .usm_ndarray )
116
+ assert m .shape == (3 , 1 , 1 , 6 , 1 )
117
+ assert dpt .allclose (m , dpt .asarray (1 , dtype = m .dtype ))
118
+
119
+ s = dpt .var (x , axis = (1 , 2 , - 1 ), keepdims = True )
120
+ assert isinstance (s , dpt .usm_ndarray )
121
+ assert s .shape == (3 , 1 , 1 , 6 , 1 )
122
+ assert dpt .allclose (s , dpt .asarray (0 , dtype = s .dtype ))
123
+
124
+
125
+ def test_stat_fns_empty_axis ():
126
+ get_queue_or_skip ()
127
+
128
+ x = dpt .reshape (dpt .arange (3 * 4 * 5 , dtype = "f4" ), (3 , 4 , 5 ))
129
+ m = dpt .mean (x , axis = ())
130
+
131
+ assert x .shape == m .shape
132
+ assert dpt .all (x == m )
133
+
134
+ s = dpt .var (x , axis = ())
135
+ assert x .shape == s .shape
136
+ assert dpt .all (s == 0 )
137
+
138
+ d = dpt .std (x , axis = ())
139
+ assert x .shape == d .shape
140
+ assert dpt .all (d == 0 )
141
+
142
+
143
+ def test_mean ():
144
+ get_queue_or_skip ()
145
+
146
+ x = dpt .reshape (dpt .arange (9 , dtype = "f4" ), (3 , 3 ))
147
+ m = dpt .mean (x )
148
+ expected = dpt .asarray (4 , dtype = "f4" )
149
+ assert dpt .allclose (m , expected )
150
+
151
+ m = dpt .mean (x , axis = 0 )
152
+ expected = dpt .arange (3 , 6 , dtype = "f4" )
153
+ assert dpt .allclose (m , expected )
154
+
155
+ m = dpt .mean (x , axis = 1 )
156
+ expected = dpt .asarray ([1 , 4 , 7 ], dtype = "f4" )
157
+ assert dpt .allclose (m , expected )
158
+
159
+
160
+ def test_var_std ():
161
+ get_queue_or_skip ()
162
+
163
+ x = dpt .reshape (dpt .arange (9 , dtype = "f4" ), (3 , 3 ))
164
+ r = dpt .var (x )
165
+ expected = dpt .asarray (6.666666507720947 , dtype = "f4" )
166
+ assert dpt .allclose (r , expected )
167
+
168
+ r1 = dpt .var (x , correction = 3 )
169
+ expected1 = dpt .asarray (10.0 , dtype = "f4" )
170
+ assert dpt .allclose (r1 , expected1 )
171
+
172
+ r = dpt .std (x )
173
+ expected = dpt .sqrt (expected )
174
+ assert dpt .allclose (r , expected )
175
+
176
+ r1 = dpt .std (x , correction = 3 )
177
+ expected1 = dpt .sqrt (expected1 )
178
+ assert dpt .allclose (r1 , expected1 )
179
+
180
+ r = dpt .var (x , axis = 0 )
181
+ expected = dpt .full (x .shape [1 ], 6 , dtype = "f4" )
182
+ assert dpt .allclose (r , expected )
183
+
184
+ r1 = dpt .var (x , axis = 0 , correction = 1 )
185
+ expected1 = dpt .full (x .shape [1 ], 9 , dtype = "f4" )
186
+ assert dpt .allclose (r1 , expected1 )
187
+
188
+ r = dpt .std (x , axis = 0 )
189
+ expected = dpt .sqrt (expected )
190
+ assert dpt .allclose (r , expected )
191
+
192
+ r1 = dpt .std (x , axis = 0 , correction = 1 )
193
+ expected1 = dpt .sqrt (expected1 )
194
+ assert dpt .allclose (r1 , expected1 )
195
+
196
+ r = dpt .var (x , axis = 1 )
197
+ expected = dpt .full (x .shape [0 ], 0.6666666865348816 , dtype = "f4" )
198
+ assert dpt .allclose (r , expected )
199
+
200
+ r1 = dpt .var (x , axis = 1 , correction = 1 )
201
+ expected1 = dpt .ones (x .shape [0 ], dtype = "f4" )
202
+ assert dpt .allclose (r1 , expected1 )
203
+
204
+ r = dpt .std (x , axis = 1 )
205
+ expected = dpt .sqrt (expected )
206
+ assert dpt .allclose (r , expected )
207
+
208
+ r1 = dpt .std (x , axis = 1 , correction = 1 )
209
+ expected1 = dpt .sqrt (expected1 )
210
+ assert dpt .allclose (r1 , expected1 )
211
+
212
+
213
+ def test_var_axis_length_correction ():
214
+ get_queue_or_skip ()
215
+
216
+ x = dpt .reshape (dpt .arange (9 , dtype = "f4" ), (3 , 3 ))
217
+
218
+ r = dpt .var (x , correction = x .size )
219
+ assert dpt .isnan (r )
220
+
221
+ r = dpt .var (x , axis = 0 , correction = x .shape [0 ])
222
+ assert dpt .all (dpt .isnan (r ))
223
+
224
+ r = dpt .var (x , axis = 1 , correction = x .shape [1 ])
225
+ assert dpt .all (dpt .isnan (r ))
226
+
227
+
228
+ def test_stat_function_errors ():
229
+ d = dict ()
230
+ with pytest .raises (TypeError ):
231
+ dpt .var (d )
232
+ with pytest .raises (TypeError ):
233
+ dpt .std (d )
234
+ with pytest .raises (TypeError ):
235
+ dpt .mean (d )
236
+
237
+ x = dpt .empty (1 , dtype = "f4" )
238
+ with pytest .raises (TypeError ):
239
+ dpt .var (x , axis = d )
240
+ with pytest .raises (TypeError ):
241
+ dpt .std (x , axis = d )
242
+ with pytest .raises (TypeError ):
243
+ dpt .mean (x , axis = d )
244
+
245
+ with pytest .raises (TypeError ):
246
+ dpt .var (x , correction = d )
247
+ with pytest .raises (TypeError ):
248
+ dpt .std (x , correction = d )
249
+
250
+ x = dpt .empty (1 , dtype = "c8" )
251
+ with pytest .raises (ValueError ):
252
+ dpt .var (x )
253
+ with pytest .raises (ValueError ):
254
+ dpt .std (x )
0 commit comments