1
- from array_api_tests .typing import DataType
2
1
import math
3
2
from typing import Union
4
3
from itertools import takewhile , count
11
10
from . import dtype_helpers as dh
12
11
from . import pytest_helpers as ph
13
12
from . import xps
13
+ from .typing import Shape , DataType
14
14
15
15
16
16
def assert_default_float (func_name : str , dtype : DataType ):
@@ -47,6 +47,19 @@ def assert_kw_dtype(
47
47
assert out_dtype == kw_dtype , msg
48
48
49
49
50
+ def assert_shape (
51
+ func_name : str ,
52
+ out_shape : Shape ,
53
+ expected : Union [int , Shape ],
54
+ ** kw ,
55
+ ):
56
+ f_kw = ", " .join (f"{ k } ={ v } " for k , v in kw .items ())
57
+ msg = f"out.shape={ out_shape } , but should be { expected } [{ func_name } ({ f_kw } )]"
58
+ if isinstance (expected , int ):
59
+ expected = (expected ,)
60
+ assert out_shape == expected , msg
61
+
62
+
50
63
51
64
# Testing xp.arange() requires bounding the start/stop/step arguments to only
52
65
# test argument combinations compliant with the Array API, as well as to not
@@ -171,9 +184,7 @@ def test_empty(shape, kw):
171
184
assert_default_float ("empty" , out .dtype )
172
185
else :
173
186
assert_kw_dtype ("empty" , kw ["dtype" ], out .dtype )
174
- if isinstance (shape , int ):
175
- shape = (shape ,)
176
- assert out .shape == shape , f"{ shape = } , but empty() returned an array with shape { out .shape } "
187
+ assert_shape ("empty" , out .shape , shape , shape = shape )
177
188
178
189
179
190
@given (
@@ -186,7 +197,7 @@ def test_empty_like(x, kw):
186
197
ph .assert_dtype ("empty_like" , (x .dtype ,), out .dtype )
187
198
else :
188
199
assert_kw_dtype ("empty_like" , kw ["dtype" ], out .dtype )
189
- assert out .shape == x . shape , f" { x .shape = } , but empty_like() returned an array with shape { out . shape } "
200
+ assert_shape ( "empty_like" , out .shape , x .shape )
190
201
191
202
192
203
@given (
@@ -204,7 +215,7 @@ def test_eye(n_rows, n_cols, kw):
204
215
else :
205
216
assert_kw_dtype ("eye" , kw ["dtype" ], out .dtype )
206
217
_n_cols = n_rows if n_cols is None else n_cols
207
- assert out .shape == (n_rows , _n_cols ), "eye() produced an array with incorrect shape"
218
+ assert_shape ( "eye" , out .shape , (n_rows , _n_cols ), n_rows = n_rows , n_cols = n_cols )
208
219
for i in range (n_rows ):
209
220
for j in range (_n_cols ):
210
221
if j - i == kw .get ("k" , 0 ):
@@ -254,7 +265,7 @@ def test_full(shape, fill_value, kw):
254
265
assert_default_float ("full" , out .dtype )
255
266
else :
256
267
assert_kw_dtype ("full" , kw ["dtype" ], out .dtype )
257
- assert out .shape == shape , f" { shape = } , but full() returned an array with shape { out . shape } "
268
+ assert_shape ( "full" , out .shape , shape , shape = shape )
258
269
if dh .is_float_dtype (out .dtype ) and math .isnan (fill_value ):
259
270
assert ah .all (ah .isnan (out )), "full() array did not equal the fill value"
260
271
else :
@@ -280,7 +291,8 @@ def test_full_like(x, fill_value, kw):
280
291
ph .assert_dtype ("full_like" , (x .dtype ,), out .dtype )
281
292
else :
282
293
assert_kw_dtype ("full_like" , kw ["dtype" ], out .dtype )
283
- assert out .shape == x .shape , "{x.shape=}, but full_like() returned an array with shape {out.shape}"
294
+
295
+ assert_shape ("full_like" , out .shape , x .shape )
284
296
if dh .is_float_dtype (dtype ) and math .isnan (fill_value ):
285
297
assert ah .all (ah .isnan (out )), "full_like() array did not equal the fill value"
286
298
else :
@@ -309,7 +321,7 @@ def test_linspace(start, stop, num, dtype, endpoint):
309
321
else :
310
322
assert_kw_dtype ("linspace" , dtype , a .dtype )
311
323
312
- assert a .shape == ( num ,), "linspace() did not return an array with the correct shape"
324
+ assert_shape ( "linspace" , a .shape , num , start = stop , stop = stop , num = num )
313
325
314
326
if endpoint in [None , True ]:
315
327
if num > 1 :
@@ -347,7 +359,7 @@ def test_ones(shape, kw):
347
359
assert_default_float ("ones" , out .dtype )
348
360
else :
349
361
assert_kw_dtype ("ones" , kw ["dtype" ], out .dtype )
350
- assert out .shape == shape , f" { shape = } , but empty() returned an array with shape { out . shape } "
362
+ assert_shape ( "ones" , out .shape , shape , shape = shape )
351
363
dtype = kw .get ("dtype" , None ) or dh .default_float
352
364
assert ah .all (ah .equal (out , ah .asarray (make_one (dtype ), dtype = dtype ))), "ones() array did not equal 1"
353
365
@@ -362,7 +374,7 @@ def test_ones_like(x, kw):
362
374
ph .assert_dtype ("ones_like" , (x .dtype ,), out .dtype )
363
375
else :
364
376
assert_kw_dtype ("ones_like" , kw ["dtype" ], out .dtype )
365
- assert out .shape == x . shape , "{ x.shape=}, but ones_like() returned an array with shape {out.shape}"
377
+ assert_shape ( "ones_like" , out .shape , x .shape )
366
378
dtype = kw .get ("dtype" , None ) or x .dtype
367
379
assert ah .all (ah .equal (out , ah .asarray (make_one (dtype ), dtype = dtype ))), "ones_like() array elements did not equal 1"
368
380
@@ -383,7 +395,7 @@ def test_zeros(shape, kw):
383
395
assert_default_float ("zeros" , out .dtype )
384
396
else :
385
397
assert_kw_dtype ("zeros" , kw ["dtype" ], out .dtype )
386
- assert out .shape == shape , "zeros() produced an array with incorrect shape"
398
+ assert_shape ( "zeros" , out .shape , shape , shape = shape )
387
399
dtype = kw .get ("dtype" , None ) or dh .default_float
388
400
assert ah .all (ah .equal (out , ah .asarray (make_zero (dtype ), dtype = dtype ))), "zeros() array did not equal 0"
389
401
@@ -398,7 +410,6 @@ def test_zeros_like(x, kw):
398
410
ph .assert_dtype ("zeros_like" , (x .dtype ,), out .dtype )
399
411
else :
400
412
assert_kw_dtype ("zeros_like" , kw ["dtype" ], out .dtype )
401
- assert out .shape == x . shape , "{ x.shape=}, but xp.zeros_like() returned an array with shape {out.shape}"
413
+ assert_shape ( "zeros_like" , out .shape , x .shape )
402
414
dtype = kw .get ("dtype" , None ) or x .dtype
403
415
assert ah .all (ah .equal (out , ah .asarray (make_zero (dtype ), dtype = out .dtype ))), "xp.zeros_like() array elements did not ah.all xp.equal 0"
404
-
0 commit comments