@@ -828,6 +828,19 @@ def test_slogdet_kronecker_rewrite():
828828 )
829829
830830
831+ def count_kron_ops (fgraph ):
832+ return sum (
833+ [
834+ isinstance (node .op , KroneckerProduct )
835+ or (
836+ isinstance (node .op , Blockwise )
837+ and isinstance (node .op .core_op , KroneckerProduct )
838+ )
839+ for node in fgraph .apply_nodes
840+ ]
841+ )
842+
843+
831844@pytest .mark .parametrize ("add_batch" , [True , False ], ids = ["batched" , "not_batched" ])
832845@pytest .mark .parametrize ("b_ndim" , [1 , 2 ], ids = ["b_ndim_1" , "b_ndim_2" ])
833846@pytest .mark .parametrize (
@@ -858,27 +871,13 @@ def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
858871
859872 x = solve_op (C , y , ** solve_kwargs , b_ndim = b_ndim )
860873
861- def count_kron_ops (fn ):
862- return sum (
863- [
864- isinstance (node .op , KroneckerProduct )
865- or (
866- isinstance (node .op , Blockwise )
867- and isinstance (node .op .core_op , KroneckerProduct )
868- )
869- for node in fn .maker .fgraph .apply_nodes
870- ]
871- )
872-
873874 fn_expected = pytensor .function (
874875 [A , B , y ], x , mode = get_default_mode ().excluding ("rewrite_solve_kron_to_solve" )
875876 )
876- assert count_kron_ops (fn_expected ) == 1
877+ assert count_kron_ops (fn_expected . maker . fgraph ) == 1
877878
878879 fn = pytensor .function ([A , B , y ], x )
879- assert (
880- count_kron_ops (fn ) == 0
881- ), "Rewrite did not apply, KroneckerProduct found in the graph"
880+ assert count_kron_ops (fn .maker .fgraph ) == 0
882881
883882 rng = np .random .default_rng (sum (map (ord , "Go away Kron!" )))
884883 a_val = rng .normal (size = a_shape )
@@ -924,6 +923,30 @@ def count_kron_ops(fn):
924923 )
925924
926925
926+ def test_rewrite_solve_kron_to_solve_not_applied ():
927+ # Check that the rewrite is not applied when the component matrices to the kron are static and not square
928+ A = pt .tensor ("A" , shape = (3 , 2 ))
929+ B = pt .tensor ("B" , shape = (2 , 3 ))
930+ C = pt .linalg .kron (A , B )
931+
932+ y = pt .vector ("y" , shape = (6 ,))
933+ x = pt .linalg .solve (C , y )
934+
935+ fn = pytensor .function ([A , B , y ], x )
936+
937+ assert count_kron_ops (fn .maker .fgraph ) == 1
938+
939+ # If shapes are static, it should always be applied
940+ A = pt .tensor ("A" , shape = (3 , None , None ))
941+ B = pt .tensor ("B" , shape = (3 , None , None ))
942+ C = pt .linalg .kron (A , B )
943+ y = pt .tensor ("y" , shape = (None ,))
944+ x = pt .linalg .solve (C , y )
945+ fn = pytensor .function ([A , B , y ], x )
946+
947+ assert count_kron_ops (fn .maker .fgraph ) == 0
948+
949+
927950@pytest .mark .parametrize (
928951 "a_shape, b_shape" ,
929952 [((5 , 5 ), (5 , 5 )), ((50 , 50 ), (50 , 50 )), ((100 , 100 ), (100 , 100 ))],
0 commit comments