Skip to content

Commit b06b0c7

Browse files
Add shape_unsafe tag and abort rewrite if static shapes are invalid
1 parent 5be594a commit b06b0c7

File tree

2 files changed

+55
-18
lines changed

2 files changed

+55
-18
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -857,8 +857,8 @@ def rewrite_det_kronecker(fgraph, node):
857857
return [det_final]
858858

859859

860-
@register_canonicalize
861-
@register_stabilize
860+
@register_canonicalize("shape_unsafe")
861+
@register_stabilize("shape_unsafe")
862862
@node_rewriter([Blockwise])
863863
def rewrite_solve_kron_to_solve(fgraph, node):
864864
"""
@@ -896,6 +896,20 @@ def rewrite_solve_kron_to_solve(fgraph, node):
896896

897897
x1, x2 = A.owner.inputs
898898

899+
# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
900+
# We will proceed if they are unknown, but this makes the rewrite shape unsafe.
901+
x1_core_shapes = x1.type.shape[-2:]
902+
x2_core_shapes = x2.type.shape[-2:]
903+
904+
if (
905+
all(shape is not None for shape in x1_core_shapes)
906+
and x1_core_shapes[-1] != x1_core_shapes[-2]
907+
) or (
908+
all(shape is not None for shape in x2_core_shapes)
909+
and x2_core_shapes[-1] != x2_core_shapes[-2]
910+
):
911+
return None
912+
899913
m, n = x1.shape[-2], x2.shape[-2]
900914
batch_shapes = x1.shape[:-2]
901915

tests/tensor/rewriting/test_linalg.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,19 @@ def test_slogdet_kronecker_rewrite():
828828
)
829829

830830

831+
def count_kron_ops(fgraph):
832+
return sum(
833+
[
834+
isinstance(node.op, KroneckerProduct)
835+
or (
836+
isinstance(node.op, Blockwise)
837+
and isinstance(node.op.core_op, KroneckerProduct)
838+
)
839+
for node in fgraph.apply_nodes
840+
]
841+
)
842+
843+
831844
@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
832845
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
833846
@pytest.mark.parametrize(
@@ -858,27 +871,13 @@ def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
858871

859872
x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)
860873

861-
def count_kron_ops(fn):
862-
return sum(
863-
[
864-
isinstance(node.op, KroneckerProduct)
865-
or (
866-
isinstance(node.op, Blockwise)
867-
and isinstance(node.op.core_op, KroneckerProduct)
868-
)
869-
for node in fn.maker.fgraph.apply_nodes
870-
]
871-
)
872-
873874
fn_expected = pytensor.function(
874875
[A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
875876
)
876-
assert count_kron_ops(fn_expected) == 1
877+
assert count_kron_ops(fn_expected.maker.fgraph) == 1
877878

878879
fn = pytensor.function([A, B, y], x)
879-
assert (
880-
count_kron_ops(fn) == 0
881-
), "Rewrite did not apply, KroneckerProduct found in the graph"
880+
assert count_kron_ops(fn.maker.fgraph) == 0
882881

883882
rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
884883
a_val = rng.normal(size=a_shape)
@@ -924,6 +923,30 @@ def count_kron_ops(fn):
924923
)
925924

926925

926+
def test_rewrite_solve_kron_to_solve_not_applied():
927+
# Check that the rewrite is not applied when the component matrices to the kron are static and not square
928+
A = pt.tensor("A", shape=(3, 2))
929+
B = pt.tensor("B", shape=(2, 3))
930+
C = pt.linalg.kron(A, B)
931+
932+
y = pt.vector("y", shape=(6,))
933+
x = pt.linalg.solve(C, y)
934+
935+
fn = pytensor.function([A, B, y], x)
936+
937+
assert count_kron_ops(fn.maker.fgraph) == 1
938+
939+
# If shapes are static, it should always be applied
940+
A = pt.tensor("A", shape=(3, None, None))
941+
B = pt.tensor("B", shape=(3, None, None))
942+
C = pt.linalg.kron(A, B)
943+
y = pt.tensor("y", shape=(None,))
944+
x = pt.linalg.solve(C, y)
945+
fn = pytensor.function([A, B, y], x)
946+
947+
assert count_kron_ops(fn.maker.fgraph) == 0
948+
949+
927950
@pytest.mark.parametrize(
928951
"a_shape, b_shape",
929952
[((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],

0 commit comments

Comments
 (0)