Skip to content

Commit 15c06ff

Browse files
committed
Use actual Solve Op to infer output dtype
CholSolve outputs a different dtype than basic Solve in Scipy==1.15
1 parent 33c9400 commit 15c06ff

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pytensor/tensor/slinalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,10 @@ def make_node(self, A, b):
259259
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
260260

261261
# Infer dtype by solving the most simple case with 1x1 matrices
262-
o_dtype = scipy.linalg.solve(
263-
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
264-
).dtype
262+
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
263+
out_arr = [[None]]
264+
self.perform(None, inp_arr, out_arr)
265+
o_dtype = out_arr[0][0].dtype
265266
x = tensor(dtype=o_dtype, shape=b.type.shape)
266267
return Apply(self, [A, b], [x])
267268

tests/tensor/test_slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def test_solve_dtype(self):
450450
fn = function([A, b], x)
451451
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
452452

453-
assert x.dtype == x_result.dtype
453+
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
454454

455455

456456
def test_cho_solve():

0 commit comments

Comments
 (0)