@@ -299,48 +299,69 @@ def test_full_like(x, fill_value, kw):
299
299
assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), "full_like() array did not equal the fill value"
300
300
301
301
302
- @given (hh .scalars (hh .shared_dtypes , finite = True ),
303
- hh .scalars (hh .shared_dtypes , finite = True ),
304
- hh .sizes ,
305
- st .one_of (st .none (), hh .shared_dtypes ),
306
- st .one_of (st .none (), st .booleans ()),)
307
- def test_linspace (start , stop , num , dtype , endpoint ):
308
- # Skip on int start or stop that cannot be exactly represented as a float,
309
- # since we do not have good approx_equal helpers yet.
310
- if ((dtype is None or dh .is_float_dtype (dtype ))
311
- and ((isinstance (start , int ) and not ah .isintegral (xp .asarray (start , dtype = dtype )))
312
- or (isinstance (stop , int ) and not ah .isintegral (xp .asarray (stop , dtype = dtype ))))):
313
- assume (False )
314
-
315
- kwargs = {k : v for k , v in {'dtype' : dtype , 'endpoint' : endpoint }.items ()
316
- if v is not None }
317
- a = xp .linspace (start , stop , num , ** kwargs )
302
+ finite_kw = {"allow_nan" : False , "allow_infinity" : False }
318
303
319
- if dtype is None :
320
- assert_default_float ("linspace" , a .dtype )
304
+
305
+ @st .composite
306
+ def int_stops (draw , start : int , min_gap : int , m : int , M : int ):
307
+ sign = draw (st .booleans ().map (int ))
308
+ max_gap = abs (M - m )
309
+ max_int = math .floor (math .sqrt (max_gap ))
310
+ gap = draw (
311
+ st .just (0 ),
312
+ st .integers (1 , max_int ).map (lambda n : min_gap ** n )
313
+ )
314
+ stop = start + sign * gap
315
+ assume (m <= stop <= M )
316
+ return stop
317
+
318
+
319
+ @given (
320
+ num = hh .sizes ,
321
+ dtype = st .none () | xps .numeric_dtypes (),
322
+ endpoint = st .booleans (),
323
+ data = st .data (),
324
+ )
325
+ def test_linspace (num , dtype , endpoint , data ):
326
+ _dtype = dh .default_float if dtype is None else dtype
327
+
328
+ start = data .draw (xps .from_dtype (_dtype , ** finite_kw ), label = "start" )
329
+ if dh .is_float_dtype (_dtype ):
330
+ stop = data .draw (xps .from_dtype (_dtype , ** finite_kw ), label = "stop" )
331
+ # avoid overflow errors
332
+ delta = ah .asarray (stop - start , dtype = _dtype )
333
+ assume (not ah .isnan (delta ))
321
334
else :
322
- assert_kw_dtype ("linspace" , dtype , a .dtype )
335
+ if num == 0 :
336
+ stop = start
337
+ else :
338
+ min_gap = num
339
+ if endpoint :
340
+ min_gap += 1
341
+ m , M = dh .dtype_ranges [_dtype ]
342
+ stop = data .draw (int_stops (start , min_gap , m , M ), label = "stop" )
323
343
324
- assert_shape ( " linspace" , a . shape , num , start = stop , stop = stop , num = num )
344
+ out = xp . linspace ( start , stop , num , dtype = dtype , endpoint = endpoint )
325
345
326
- if endpoint in [None , True ]:
346
+ assert_shape ("linspace" , out .shape , num , start = stop , stop = stop , num = num )
347
+
348
+ if endpoint :
327
349
if num > 1 :
328
- assert ah .all (ah .equal (a [- 1 ], ah .asarray (stop , dtype = a .dtype ))), "linspace() produced an array that does not include the endpoint"
350
+ assert ah .equal (
351
+ out [- 1 ], ah .asarray (stop , dtype = out .dtype )
352
+ ), f"out[-1]={ out [- 1 ]} , but should be { stop = } [linspace()]"
329
353
else :
330
- # linspace(..., num, endpoint=False) is the same as the first num
331
- # elements of linspace(..., num+1, endpoint=True)
332
- b = xp .linspace (start , stop , num + 1 , ** {** kwargs , 'endpoint' : True })
333
- ah .assert_exactly_equal (b [:- 1 ], a )
354
+ # linspace(..., num, endpoint=True) should return an array equivalent to
355
+ # the first num elements when endpoint=False
356
+ expected = xp .linspace (start , stop , num + 1 , dtype = dtype , endpoint = True )
357
+ expected = expected [:- 1 ]
358
+ ah .assert_exactly_equal (out , expected )
334
359
335
360
if num > 0 :
336
- # We need to cast start to dtype
337
- assert ah .all (ah .equal (a [0 ], ah .asarray (start , dtype = a .dtype ))), "xp.linspace() produced an array that does not start with the start"
338
-
339
- # TODO: This requires an assert_approx_equal function
340
-
341
- # n = num - 1 if endpoint in [None, True] else num
342
- # for i in range(1, num):
343
- # assert ah.all(ah.equal(a[i], ah.full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"
361
+ assert ah .equal (
362
+ out [0 ], ah .asarray (start , dtype = out .dtype )
363
+ ), f"out[0]={ out [0 ]} , but should be { start = } [linspace()]"
364
+ # TODO: array assertions ala test_arange
344
365
345
366
346
367
def make_one (dtype ):
0 commit comments