Skip to content

Commit 266d162

Browse files
committed
Refactor shape assertions with assert_shape
1 parent 31cabac commit 266d162

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from array_api_tests.typing import DataType
21
import math
32
from typing import Union
43
from itertools import takewhile, count
@@ -11,6 +10,7 @@
1110
from . import dtype_helpers as dh
1211
from . import pytest_helpers as ph
1312
from . import xps
13+
from .typing import Shape, DataType
1414

1515

1616
def assert_default_float(func_name: str, dtype: DataType):
@@ -47,6 +47,19 @@ def assert_kw_dtype(
4747
assert out_dtype == kw_dtype, msg
4848

4949

50+
def assert_shape(
51+
func_name: str,
52+
out_shape: Shape,
53+
expected: Union[int, Shape],
54+
**kw,
55+
):
56+
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
57+
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
58+
if isinstance(expected, int):
59+
expected = (expected,)
60+
assert out_shape == expected, msg
61+
62+
5063

5164
# Testing xp.arange() requires bounding the start/stop/step arguments to only
5265
# test argument combinations compliant with the Array API, as well as to not
@@ -171,9 +184,7 @@ def test_empty(shape, kw):
171184
assert_default_float("empty", out.dtype)
172185
else:
173186
assert_kw_dtype("empty", kw["dtype"], out.dtype)
174-
if isinstance(shape, int):
175-
shape = (shape,)
176-
assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}"
187+
assert_shape("empty", out.shape, shape, shape=shape)
177188

178189

179190
@given(
@@ -186,7 +197,7 @@ def test_empty_like(x, kw):
186197
ph.assert_dtype("empty_like", (x.dtype,), out.dtype)
187198
else:
188199
assert_kw_dtype("empty_like", kw["dtype"], out.dtype)
189-
assert out.shape == x.shape, f"{x.shape=}, but empty_like() returned an array with shape {out.shape}"
200+
assert_shape("empty_like", out.shape, x.shape)
190201

191202

192203
@given(
@@ -204,7 +215,7 @@ def test_eye(n_rows, n_cols, kw):
204215
else:
205216
assert_kw_dtype("eye", kw["dtype"], out.dtype)
206217
_n_cols = n_rows if n_cols is None else n_cols
207-
assert out.shape == (n_rows, _n_cols), "eye() produced an array with incorrect shape"
218+
assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols)
208219
for i in range(n_rows):
209220
for j in range(_n_cols):
210221
if j - i == kw.get("k", 0):
@@ -254,7 +265,7 @@ def test_full(shape, fill_value, kw):
254265
assert_default_float("full", out.dtype)
255266
else:
256267
assert_kw_dtype("full", kw["dtype"], out.dtype)
257-
assert out.shape == shape, f"{shape=}, but full() returned an array with shape {out.shape}"
268+
assert_shape("full", out.shape, shape, shape=shape)
258269
if dh.is_float_dtype(out.dtype) and math.isnan(fill_value):
259270
assert ah.all(ah.isnan(out)), "full() array did not equal the fill value"
260271
else:
@@ -280,7 +291,8 @@ def test_full_like(x, fill_value, kw):
280291
ph.assert_dtype("full_like", (x.dtype,), out.dtype)
281292
else:
282293
assert_kw_dtype("full_like", kw["dtype"], out.dtype)
283-
assert out.shape == x.shape, "{x.shape=}, but full_like() returned an array with shape {out.shape}"
294+
295+
assert_shape("full_like", out.shape, x.shape)
284296
if dh.is_float_dtype(dtype) and math.isnan(fill_value):
285297
assert ah.all(ah.isnan(out)), "full_like() array did not equal the fill value"
286298
else:
@@ -309,7 +321,7 @@ def test_linspace(start, stop, num, dtype, endpoint):
309321
else:
310322
assert_kw_dtype("linspace", dtype, a.dtype)
311323

312-
assert a.shape == (num,), "linspace() did not return an array with the correct shape"
324+
assert_shape("linspace", a.shape, num, start=stop, stop=stop, num=num)
313325

314326
if endpoint in [None, True]:
315327
if num > 1:
@@ -347,7 +359,7 @@ def test_ones(shape, kw):
347359
assert_default_float("ones", out.dtype)
348360
else:
349361
assert_kw_dtype("ones", kw["dtype"], out.dtype)
350-
assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}"
362+
assert_shape("ones", out.shape, shape, shape=shape)
351363
dtype = kw.get("dtype", None) or dh.default_float
352364
assert ah.all(ah.equal(out, ah.asarray(make_one(dtype), dtype=dtype))), "ones() array did not equal 1"
353365

@@ -362,7 +374,7 @@ def test_ones_like(x, kw):
362374
ph.assert_dtype("ones_like", (x.dtype,), out.dtype)
363375
else:
364376
assert_kw_dtype("ones_like", kw["dtype"], out.dtype)
365-
assert out.shape == x.shape, "{x.shape=}, but ones_like() returned an array with shape {out.shape}"
377+
assert_shape("ones_like", out.shape, x.shape)
366378
dtype = kw.get("dtype", None) or x.dtype
367379
assert ah.all(ah.equal(out, ah.asarray(make_one(dtype), dtype=dtype))), "ones_like() array elements did not equal 1"
368380

@@ -383,7 +395,7 @@ def test_zeros(shape, kw):
383395
assert_default_float("zeros", out.dtype)
384396
else:
385397
assert_kw_dtype("zeros", kw["dtype"], out.dtype)
386-
assert out.shape == shape, "zeros() produced an array with incorrect shape"
398+
assert_shape("zeros", out.shape, shape, shape=shape)
387399
dtype = kw.get("dtype", None) or dh.default_float
388400
assert ah.all(ah.equal(out, ah.asarray(make_zero(dtype), dtype=dtype))), "zeros() array did not equal 0"
389401

@@ -398,7 +410,6 @@ def test_zeros_like(x, kw):
398410
ph.assert_dtype("zeros_like", (x.dtype,), out.dtype)
399411
else:
400412
assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
401-
assert out.shape == x.shape, "{x.shape=}, but xp.zeros_like() returned an array with shape {out.shape}"
413+
assert_shape("zeros_like", out.shape, x.shape)
402414
dtype = kw.get("dtype", None) or x.dtype
403415
assert ah.all(ah.equal(out, ah.asarray(make_zero(dtype), dtype=out.dtype))), "xp.zeros_like() array elements did not ah.all xp.equal 0"
404-

0 commit comments

Comments
 (0)