@@ -4749,15 +4749,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
4749
4749
4750
4750
4751
4751
@pytest .mark .parametrize ("left_multiply" , [True , False ], ids = ["left" , "right" ])
4752
- def test_local_block_diag_dot_to_dot_block_diag (left_multiply ):
4752
+ @pytest .mark .parametrize (
4753
+ "batch_left" , [True , False ], ids = ["batched_left" , "unbatched_left" ]
4754
+ )
4755
+ @pytest .mark .parametrize (
4756
+ "batch_right" , [True , False ], ids = ["batched_right" , "unbatched_right" ]
4757
+ )
4758
+ def test_local_block_diag_dot_to_dot_block_diag (left_multiply , batch_left , batch_right ):
4753
4759
"""
4754
4760
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
4755
4761
"""
4756
4762
a = tensor ("a" , shape = (4 , 2 ))
4757
- b = tensor ("b" , shape = (2 , 4 ))
4763
+ b = tensor ("b" , shape = (2 , 4 ) if not batch_left else ( 3 , 2 , 4 ) )
4758
4764
c = tensor ("c" , shape = (4 , 4 ))
4759
4765
d = tensor ("d" , shape = (10 , 10 ))
4760
- e = tensor ("e" , shape = (10 , 10 ))
4766
+ e = tensor ("e" , shape = (10 , 10 ) if not batch_right else ( 3 , 1 , 10 , 10 ) )
4761
4767
4762
4768
x = pt .linalg .block_diag (a , b , c )
4763
4769
@@ -4767,30 +4773,38 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
4767
4773
else :
4768
4774
out = [d @ x , e @ x ]
4769
4775
4770
- fn = pytensor .function ([a , b , c , d , e ], out , mode = rewrite_mode )
4776
+ with config .change_flags (optimizer_verbose = True ):
4777
+ fn = pytensor .function ([a , b , c , d , e ], out , mode = rewrite_mode )
4778
+
4771
4779
assert not any (
4772
4780
isinstance (node .op , BlockDiagonal ) for node in fn .maker .fgraph .toposort ()
4773
4781
)
4774
4782
4775
4783
fn_expected = pytensor .function (
4776
4784
[a , b , c , d , e ],
4777
4785
out ,
4778
- mode = rewrite_mode . excluding ( "local_block_diag_dot_to_dot_block_diag" ),
4786
+ mode = Mode ( linker = "py" , optimizer = None ),
4779
4787
)
4780
4788
4789
+ # TODO: Count Dots
4790
+
4781
4791
rng = np .random .default_rng ()
4782
4792
a_val = rng .normal (size = a .type .shape ).astype (a .type .dtype )
4783
4793
b_val = rng .normal (size = b .type .shape ).astype (b .type .dtype )
4784
4794
c_val = rng .normal (size = c .type .shape ).astype (c .type .dtype )
4785
4795
d_val = rng .normal (size = d .type .shape ).astype (d .type .dtype )
4786
4796
e_val = rng .normal (size = e .type .shape ).astype (e .type .dtype )
4787
4797
4788
- np .testing .assert_allclose (
4789
- fn (a_val , b_val , c_val , d_val , e_val ),
4790
- fn_expected (a_val , b_val , c_val , d_val , e_val ),
4791
- atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4792
- rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4793
- )
4798
+ rewrite_outs = fn (a_val , b_val , c_val , d_val , e_val )
4799
+ expected_outs = fn_expected (a_val , b_val , c_val , d_val , e_val )
4800
+
4801
+ for out , expected in zip (rewrite_outs , expected_outs ):
4802
+ np .testing .assert_allclose (
4803
+ out ,
4804
+ expected ,
4805
+ atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4806
+ rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4807
+ )
4794
4808
4795
4809
4796
4810
@pytest .mark .parametrize ("rewrite" , [True , False ], ids = ["rewrite" , "no_rewrite" ])
0 commit comments