diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 131726b7..3864d426 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -1,5 +1,4 @@ import re -import itertools from contextlib import contextmanager from functools import reduce, wraps import math @@ -309,18 +308,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes( # For now, just generate stacks of diagonal matrices. n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) stack_shape = draw(stack_shapes) - shape = stack_shape + (n, n) - d = draw(arrays(dtypes, shape=n*prod(stack_shape), + d = draw(arrays(dtypes, shape=(*stack_shape, 1, n), elements=dict(allow_nan=False, allow_infinity=False))) # Functions that require invertible matrices may do anything when it is # singular, including raising an exception, so we make sure the diagonals # are sufficiently nonzero to avoid any numerical issues. assume(xp.all(xp.abs(d) > 0.5)) - - a = xp.zeros(shape) - for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): - a[idx + (i, i)] = d[j] - return a + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) + return xp.where(diag_mask, d, xp.zeros_like(d)) # TODO: Better name @composite