Skip to content

Commit 7e613fb

Browse files
committed
Rudimentary re-implementation of test_linspace
1 parent 266d162 commit 7e613fb

File tree

1 file changed

+55
-34
lines changed

1 file changed

+55
-34
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -299,48 +299,69 @@ def test_full_like(x, fill_value, kw):
299299
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value"
300300

301301

302-
@given(hh.scalars(hh.shared_dtypes, finite=True),
303-
hh.scalars(hh.shared_dtypes, finite=True),
304-
hh.sizes,
305-
st.one_of(st.none(), hh.shared_dtypes),
306-
st.one_of(st.none(), st.booleans()),)
307-
def test_linspace(start, stop, num, dtype, endpoint):
308-
# Skip on int start or stop that cannot be exactly represented as a float,
309-
# since we do not have good approx_equal helpers yet.
310-
if ((dtype is None or dh.is_float_dtype(dtype))
311-
and ((isinstance(start, int) and not ah.isintegral(xp.asarray(start, dtype=dtype)))
312-
or (isinstance(stop, int) and not ah.isintegral(xp.asarray(stop, dtype=dtype))))):
313-
assume(False)
314-
315-
kwargs = {k: v for k, v in {'dtype': dtype, 'endpoint': endpoint}.items()
316-
if v is not None}
317-
a = xp.linspace(start, stop, num, **kwargs)
302+
finite_kw = {"allow_nan": False, "allow_infinity": False}
318303

319-
if dtype is None:
320-
assert_default_float("linspace", a.dtype)
304+
305+
@st.composite
306+
def int_stops(draw, start: int, min_gap: int, m: int, M: int):
307+
sign = draw(st.booleans().map(int))
308+
max_gap = abs(M - m)
309+
max_int = math.floor(math.sqrt(max_gap))
310+
gap = draw(
311+
st.just(0),
312+
st.integers(1, max_int).map(lambda n: min_gap ** n)
313+
)
314+
stop = start + sign * gap
315+
assume(m <= stop <= M)
316+
return stop
317+
318+
319+
@given(
320+
num=hh.sizes,
321+
dtype=st.none() | xps.numeric_dtypes(),
322+
endpoint=st.booleans(),
323+
data=st.data(),
324+
)
325+
def test_linspace(num, dtype, endpoint, data):
326+
_dtype = dh.default_float if dtype is None else dtype
327+
328+
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
329+
if dh.is_float_dtype(_dtype):
330+
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
331+
# avoid overflow errors
332+
delta = ah.asarray(stop - start, dtype=_dtype)
333+
assume(not ah.isnan(delta))
321334
else:
322-
assert_kw_dtype("linspace", dtype, a.dtype)
335+
if num == 0:
336+
stop = start
337+
else:
338+
min_gap = num
339+
if endpoint:
340+
min_gap += 1
341+
m, M = dh.dtype_ranges[_dtype]
342+
stop = data.draw(int_stops(start, min_gap, m, M), label="stop")
323343

324-
assert_shape("linspace", a.shape, num, start=stop, stop=stop, num=num)
344+
out = xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
325345

326-
if endpoint in [None, True]:
346+
assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num)
347+
348+
if endpoint:
327349
if num > 1:
328-
assert ah.all(ah.equal(a[-1], ah.asarray(stop, dtype=a.dtype))), "linspace() produced an array that does not include the endpoint"
350+
assert ah.equal(
351+
out[-1], ah.asarray(stop, dtype=out.dtype)
352+
), f"out[-1]={out[-1]}, but should be {stop=} [linspace()]"
329353
else:
330-
# linspace(..., num, endpoint=False) is the same as the first num
331-
# elements of linspace(..., num+1, endpoint=True)
332-
b = xp.linspace(start, stop, num + 1, **{**kwargs, 'endpoint': True})
333-
ah.assert_exactly_equal(b[:-1], a)
354+
# linspace(..., num, endpoint=True) should return an array equivalent to
355+
# the first num elements when endpoint=False
356+
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
357+
expected = expected[:-1]
358+
ah.assert_exactly_equal(out, expected)
334359

335360
if num > 0:
336-
# We need to cast start to dtype
337-
assert ah.all(ah.equal(a[0], ah.asarray(start, dtype=a.dtype))), "xp.linspace() produced an array that does not start with the start"
338-
339-
# TODO: This requires an assert_approx_equal function
340-
341-
# n = num - 1 if endpoint in [None, True] else num
342-
# for i in range(1, num):
343-
# assert ah.all(ah.equal(a[i], ah.full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"
361+
assert ah.equal(
362+
out[0], ah.asarray(start, dtype=out.dtype)
363+
), f"out[0]={out[0]}, but should be {start=} [linspace()]"
364+
# TODO: array assertions ala test_arange
344365

345366

346367
def make_one(dtype):

0 commit comments

Comments
 (0)