@@ -169,7 +169,12 @@ def test_eigvalsh_grad():
169
169
)
170
170
171
171
172
- class TestSolveBase (utt .InferShapeTester ):
172
+ class TestSolveBase :
173
+ class SolveTest (SolveBase ):
174
+ def perform (self , node , inputs , outputs ):
175
+ A , b = inputs
176
+ outputs [0 ][0 ] = scipy .linalg .solve (A , b )
177
+
173
178
@pytest .mark .parametrize (
174
179
"A_func, b_func, error_message" ,
175
180
[
@@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message):
191
196
with pytest .raises (ValueError , match = error_message ):
192
197
A = A_func ()
193
198
b = b_func ()
194
- SolveBase (b_ndim = 2 )(A , b )
199
+ self . SolveTest (b_ndim = 2 )(A , b )
195
200
196
201
def test__repr__ (self ):
197
202
np .random .default_rng (utt .fetch_seed ())
198
203
A = matrix ()
199
204
b = matrix ()
200
- y = SolveBase (b_ndim = 2 )(A , b )
205
+ y = self . SolveTest (b_ndim = 2 )(A , b )
201
206
assert (
202
207
y .__repr__ ()
203
- == "SolveBase {lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
208
+ == "SolveTest {lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
204
209
)
205
210
206
211
@@ -239,8 +244,9 @@ def test_correctness(self):
239
244
A_val = np .asarray (rng .random ((5 , 5 )), dtype = config .floatX )
240
245
A_val = np .dot (A_val .transpose (), A_val )
241
246
242
- assert np .allclose (
243
- scipy .linalg .solve (A_val , b_val ), gen_solve_func (A_val , b_val )
247
+ np .testing .assert_allclose (
248
+ scipy .linalg .solve (A_val , b_val , assume_a = "gen" ),
249
+ gen_solve_func (A_val , b_val ),
244
250
)
245
251
246
252
A_undef = np .array (
@@ -253,7 +259,7 @@ def test_correctness(self):
253
259
],
254
260
dtype = config .floatX ,
255
261
)
256
- assert np .allclose (
262
+ np .testing . assert_allclose (
257
263
scipy .linalg .solve (A_undef , b_val ), gen_solve_func (A_undef , b_val )
258
264
)
259
265
@@ -450,7 +456,7 @@ def test_solve_dtype(self):
450
456
fn = function ([A , b ], x )
451
457
x_result = fn (A_val .astype (A_dtype ), b_val .astype (b_dtype ))
452
458
453
- assert x .dtype == x_result .dtype
459
+ assert x .dtype == x_result .dtype , ( A_dtype , b_dtype )
454
460
455
461
456
462
def test_cho_solve ():
0 commit comments