@@ -21,6 +21,13 @@ def stub_module(name):
21
21
def extension_module (name ):
22
22
return name in submodules and name in function_stubs .__all__
23
23
24
+ extension_module_names = {}
25
+ for n in function_stubs .__all__ :
26
+ if extension_module (n ):
27
+ extension_module_names .update ({i : n for i in getattr (function_stubs , n ).__all__ })
28
+
29
+ all_names = function_stubs .__all__ + list (extension_module_names )
30
+
24
31
def array_method (name ):
25
32
return stub_module (name ) == 'array_object'
26
33
@@ -59,10 +66,13 @@ def example_argument(arg, func_name, dtype):
59
66
endpoint = False ,
60
67
fill_value = 1.0 ,
61
68
from_ = int64 ,
69
+ full_matrices = False ,
62
70
k = 1 ,
63
71
keepdims = True ,
64
72
key = 0 ,
65
73
indexing = 'ij' ,
74
+ mode = 'complete' ,
75
+ n = 2 ,
66
76
n_cols = 1 ,
67
77
n_rows = 1 ,
68
78
num = 2 ,
@@ -73,6 +83,7 @@ def example_argument(arg, func_name, dtype):
73
83
return_counts = True ,
74
84
return_index = True ,
75
85
return_inverse = True ,
86
+ rtol = 1e-10 ,
76
87
self = ones ((3 , 3 ), dtype = dtype ),
77
88
shape = (1 , 3 , 3 ),
78
89
shift = 1 ,
@@ -83,6 +94,7 @@ def example_argument(arg, func_name, dtype):
83
94
stop = 1 ,
84
95
to = float64 ,
85
96
type = float64 ,
97
+ upper = True ,
86
98
value = 0 ,
87
99
x1 = ones ((1 , 3 , 3 ), dtype = dtype ),
88
100
x2 = ones ((1 , 3 , 3 ), dtype = dtype ),
@@ -97,8 +109,6 @@ def example_argument(arg, func_name, dtype):
97
109
if func_name == 'squeeze' and arg == 'axis' :
98
110
return 0
99
111
# ones() is not invertible
100
- elif func_name == 'inv' and arg == 'x' :
101
- return eye (3 )
102
112
# finfo requires a float dtype and iinfo requires an int dtype
103
113
elif func_name == 'iinfo' and arg == 'type' :
104
114
return int64
@@ -109,14 +119,24 @@ def example_argument(arg, func_name, dtype):
109
119
# contractible axes or a 2-tuple or axes
110
120
elif func_name == 'tensordot' and arg == 'axes' :
111
121
return 1
122
+ # The inputs to outer() must be 1-dimensional
123
+ elif func_name == 'outer' and arg in ['x1' , 'x2' ]:
124
+ return ones ((3 ,), dtype = dtype )
125
+ # Linear algebra functions tend to error if the input isn't "nice" as
126
+ # a matrix
127
+ elif arg .startswith ('x' ) and func_name in function_stubs .linalg .__all__ :
128
+ return eye (3 )
112
129
return known_args [arg ]
113
130
else :
114
131
raise RuntimeError (f"Don't know how to test argument { arg } . Please update test_signatures.py" )
115
132
116
- @pytest .mark .parametrize ('name' , function_stubs . __all__ )
133
+ @pytest .mark .parametrize ('name' , all_names )
117
134
def test_has_names (name ):
118
135
if extension_module (name ):
119
136
assert hasattr (mod , name ), f'{ mod_name } is missing the { name } extension'
137
+ elif name in extension_module_names :
138
+ extension_mod = extension_module_names [name ]
139
+ assert hasattr (getattr (mod , extension_mod ), name ), f"{ mod_name } is missing the { function_category (name )} extension function { name } ()"
120
140
elif array_method (name ):
121
141
arr = ones ((1 , 1 ))
122
142
if getattr (function_stubs .array_object , name ) is None :
@@ -126,7 +146,7 @@ def test_has_names(name):
126
146
else :
127
147
assert hasattr (mod , name ), f"{ mod_name } is missing the { function_category (name )} function { name } ()"
128
148
129
- @pytest .mark .parametrize ('name' , function_stubs . __all__ )
149
+ @pytest .mark .parametrize ('name' , all_names )
130
150
def test_function_positional_args (name ):
131
151
# Note: We can't actually test that positional arguments are
132
152
# positional-only, as that would require knowing the argument name and
@@ -157,12 +177,16 @@ def test_function_positional_args(name):
157
177
_mod = ones ((), dtype = float64 )
158
178
else :
159
179
_mod = example_argument ('self' , name , dtype )
180
+ stub_func = getattr (function_stubs , name )
181
+ elif name in extension_module_names :
182
+ _mod = getattr (mod , extension_module_names [name ])
183
+ stub_func = getattr (getattr (function_stubs , extension_module_names [name ]), name )
160
184
else :
161
185
_mod = mod
186
+ stub_func = getattr (function_stubs , name )
162
187
163
188
if not hasattr (_mod , name ):
164
189
pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
165
- stub_func = getattr (function_stubs , name )
166
190
if stub_func is None :
167
191
# TODO: Can we make this skip the parameterization entirely?
168
192
pytest .skip (f"{ name } is not a function, skipping." )
@@ -198,19 +222,23 @@ def test_function_positional_args(name):
198
222
# NumPy ufuncs raise ValueError instead of TypeError
199
223
raises ((TypeError , ValueError ), lambda : mod_func (* args [:n ]), f"{ name } () should not accept { n } positional arguments" )
200
224
201
- @pytest .mark .parametrize ('name' , function_stubs . __all__ )
225
+ @pytest .mark .parametrize ('name' , all_names )
202
226
def test_function_keyword_only_args (name ):
203
227
if extension_module (name ):
204
228
return
205
229
206
230
if array_method (name ):
207
231
_mod = ones ((1 , 1 ))
232
+ stub_func = getattr (function_stubs , name )
233
+ elif name in extension_module_names :
234
+ _mod = getattr (mod , extension_module_names [name ])
235
+ stub_func = getattr (getattr (function_stubs , extension_module_names [name ]), name )
208
236
else :
209
237
_mod = mod
238
+ stub_func = getattr (function_stubs , name )
210
239
211
240
if not hasattr (_mod , name ):
212
241
pytest .skip (f"{ mod_name } does not have { name } (), skipping." )
213
- stub_func = getattr (function_stubs , name )
214
242
if stub_func is None :
215
243
# TODO: Can we make this skip the parameterization entirely?
216
244
pytest .skip (f"{ name } is not a function, skipping." )
0 commit comments