Skip to content

Commit 1303ef5

Browse files
committed
Update the test_signatures tests to also test the linear algebra extension
We still need to figure out a clean way to let people only selectively enable these tests, as the extension is optional
1 parent fe46b25 commit 1303ef5

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

array_api_tests/test_signatures.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ def stub_module(name):
2121
def extension_module(name):
2222
return name in submodules and name in function_stubs.__all__
2323

24+
extension_module_names = {}
25+
for n in function_stubs.__all__:
26+
if extension_module(n):
27+
extension_module_names.update({i: n for i in getattr(function_stubs, n).__all__})
28+
29+
all_names = function_stubs.__all__ + list(extension_module_names)
30+
2431
def array_method(name):
2532
return stub_module(name) == 'array_object'
2633

@@ -59,10 +66,13 @@ def example_argument(arg, func_name, dtype):
5966
endpoint=False,
6067
fill_value=1.0,
6168
from_=int64,
69+
full_matrices=False,
6270
k=1,
6371
keepdims=True,
6472
key=0,
6573
indexing='ij',
74+
mode='complete',
75+
n=2,
6676
n_cols=1,
6777
n_rows=1,
6878
num=2,
@@ -73,6 +83,7 @@ def example_argument(arg, func_name, dtype):
7383
return_counts=True,
7484
return_index=True,
7585
return_inverse=True,
86+
rtol=1e-10,
7687
self=ones((3, 3), dtype=dtype),
7788
shape=(1, 3, 3),
7889
shift=1,
@@ -83,6 +94,7 @@ def example_argument(arg, func_name, dtype):
8394
stop=1,
8495
to=float64,
8596
type=float64,
97+
upper=True,
8698
value=0,
8799
x1=ones((1, 3, 3), dtype=dtype),
88100
x2=ones((1, 3, 3), dtype=dtype),
@@ -97,8 +109,6 @@ def example_argument(arg, func_name, dtype):
97109
if func_name == 'squeeze' and arg == 'axis':
98110
return 0
99111
# ones() is not invertible
100-
elif func_name == 'inv' and arg == 'x':
101-
return eye(3)
102112
# finfo requires a float dtype and iinfo requires an int dtype
103113
elif func_name == 'iinfo' and arg == 'type':
104114
return int64
@@ -109,14 +119,24 @@ def example_argument(arg, func_name, dtype):
109119
# contractible axes or a 2-tuple or axes
110120
elif func_name == 'tensordot' and arg == 'axes':
111121
return 1
122+
# The inputs to outer() must be 1-dimensional
123+
elif func_name == 'outer' and arg in ['x1', 'x2']:
124+
return ones((3,), dtype=dtype)
125+
# Linear algebra functions tend to error if the input isn't "nice" as
126+
# a matrix
127+
elif arg.startswith('x') and func_name in function_stubs.linalg.__all__:
128+
return eye(3)
112129
return known_args[arg]
113130
else:
114131
raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
115132

116-
@pytest.mark.parametrize('name', function_stubs.__all__)
133+
@pytest.mark.parametrize('name', all_names)
117134
def test_has_names(name):
118135
if extension_module(name):
119136
assert hasattr(mod, name), f'{mod_name} is missing the {name} extension'
137+
elif name in extension_module_names:
138+
extension_mod = extension_module_names[name]
139+
assert hasattr(getattr(mod, extension_mod), name), f"{mod_name} is missing the {function_category(name)} extension function {name}()"
120140
elif array_method(name):
121141
arr = ones((1, 1))
122142
if getattr(function_stubs.array_object, name) is None:
@@ -126,7 +146,7 @@ def test_has_names(name):
126146
else:
127147
assert hasattr(mod, name), f"{mod_name} is missing the {function_category(name)} function {name}()"
128148

129-
@pytest.mark.parametrize('name', function_stubs.__all__)
149+
@pytest.mark.parametrize('name', all_names)
130150
def test_function_positional_args(name):
131151
# Note: We can't actually test that positional arguments are
132152
# positional-only, as that would require knowing the argument name and
@@ -157,12 +177,16 @@ def test_function_positional_args(name):
157177
_mod = ones((), dtype=float64)
158178
else:
159179
_mod = example_argument('self', name, dtype)
180+
stub_func = getattr(function_stubs, name)
181+
elif name in extension_module_names:
182+
_mod = getattr(mod, extension_module_names[name])
183+
stub_func = getattr(getattr(function_stubs, extension_module_names[name]), name)
160184
else:
161185
_mod = mod
186+
stub_func = getattr(function_stubs, name)
162187

163188
if not hasattr(_mod, name):
164189
pytest.skip(f"{mod_name} does not have {name}(), skipping.")
165-
stub_func = getattr(function_stubs, name)
166190
if stub_func is None:
167191
# TODO: Can we make this skip the parameterization entirely?
168192
pytest.skip(f"{name} is not a function, skipping.")
@@ -198,19 +222,23 @@ def test_function_positional_args(name):
198222
# NumPy ufuncs raise ValueError instead of TypeError
199223
raises((TypeError, ValueError), lambda: mod_func(*args[:n]), f"{name}() should not accept {n} positional arguments")
200224

201-
@pytest.mark.parametrize('name', function_stubs.__all__)
225+
@pytest.mark.parametrize('name', all_names)
202226
def test_function_keyword_only_args(name):
203227
if extension_module(name):
204228
return
205229

206230
if array_method(name):
207231
_mod = ones((1, 1))
232+
stub_func = getattr(function_stubs, name)
233+
elif name in extension_module_names:
234+
_mod = getattr(mod, extension_module_names[name])
235+
stub_func = getattr(getattr(function_stubs, extension_module_names[name]), name)
208236
else:
209237
_mod = mod
238+
stub_func = getattr(function_stubs, name)
210239

211240
if not hasattr(_mod, name):
212241
pytest.skip(f"{mod_name} does not have {name}(), skipping.")
213-
stub_func = getattr(function_stubs, name)
214242
if stub_func is None:
215243
# TODO: Can we make this skip the parameterization entirely?
216244
pytest.skip(f"{name} is not a function, skipping.")

0 commit comments

Comments
 (0)