@@ -767,60 +767,19 @@ def ref_rand(size, w, mu, sigma):
         )


-@pytest.mark.xfail(reason="NormalMixture not refactored yet")
 class TestMixtureVsLatent(SeededTest):
-    def setup_method(self, *args, **kwargs):
-        super().setup_method(*args, **kwargs)
-        self.nd = 3
-        self.npop = 3
-        self.mus = at.as_tensor_variable(
-            np.tile(
-                np.reshape(
-                    np.arange(self.npop),
-                    (
-                        1,
-                        -1,
-                    ),
-                ),
-                (
-                    self.nd,
-                    1,
-                ),
-            )
-        )
+    """This class contains tests that compare a marginal Mixture with a latent indexed Mixture"""

-    def test_1d_w(self):
-        nd = self.nd
-        npop = self.npop
-        mus = self.mus
-        size = 100
-        with Model() as model:
-            m = NormalMixture(
-                "m", w=np.ones(npop) / npop, mu=mus, sigma=1e-5, comp_shape=(nd, npop), shape=nd
-            )
-            z = Categorical("z", p=np.ones(npop) / npop)
-            latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd)
+    def test_scalar_components(self):
+        nd = 3
+        npop = 4
+        # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+        mus = at.constant(np.full((nd, npop), np.arange(npop)))

-            m_val = m.random(size=size)
-            latent_m_val = latent_m.random(size=size)
-        assert m_val.shape == latent_m_val.shape
-        # Test that each element in axis = -1 comes from the same mixture
-        # component
-        assert all(np.all(np.diff(m_val) < 1e-3, axis=-1))
-        assert all(np.all(np.diff(latent_m_val) < 1e-3, axis=-1))
-
-        self.samples_from_same_distribution(m_val, latent_m_val)
-        self.logp_matches(m, latent_m, z, npop, model=model)
-
-    def test_2d_w(self):
-        nd = self.nd
-        npop = self.npop
-        mus = self.mus
-        size = 100
-        with Model() as model:
+        with Model(rng_seeder=self.get_random_state()) as model:
             m = NormalMixture(
                 "m",
-                w=np.ones((nd, npop)) / npop,
+                w=np.ones(npop) / npop,
                 mu=mus,
                 sigma=1e-5,
                 comp_shape=(nd, npop),
@@ -830,15 +789,55 @@ def test_2d_w(self):
             mu = at.as_tensor_variable([mus[i, z[i]] for i in range(nd)])
             latent_m = Normal("latent_m", mu=mu, sigma=1e-5, shape=nd)

-            m_val = m.random(size=size)
-            latent_m_val = latent_m.random(size=size)
+        size = 100
+        m_val = draw(m, draws=size)
+        latent_m_val = draw(latent_m, draws=size)
+
         assert m_val.shape == latent_m_val.shape
         # Test that each element in axis = -1 can come from independent
         # components
         assert not all(np.all(np.diff(m_val) < 1e-3, axis=-1))
         assert not all(np.all(np.diff(latent_m_val) < 1e-3, axis=-1))
+        self.samples_from_same_distribution(m_val, latent_m_val)
+
+        # Check that logp is the same whether elements of the last axis are mixed or not
+        logp_fn = model.compile_logp(vars=[m])
+        assert np.isclose(logp_fn({"m": [0, 0, 0]}), logp_fn({"m": [0, 1, 2]}))
+        self.logp_matches(m, latent_m, z, npop, model=model)
+
+    def test_vector_components(self):
+        nd = 3
+        npop = 4
+        # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+        mus = at.constant(np.full((nd, npop), np.arange(npop)))
+
+        with Model(rng_seeder=self.get_random_state()) as model:
+            m = Mixture(
+                "m",
+                w=np.ones(npop) / npop,
+                # MvNormal distribution with squared sigma diagonal covariance should
+                # be equal to vector of Normals from latent_m
+                comp_dists=[MvNormal.dist(mus[:, i], np.eye(nd) * 1e-5 ** 2) for i in range(npop)],
+            )
+            z = Categorical("z", p=np.ones(npop) / npop)
+            latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd)

+        size = 100
+        m_val = draw(m, draws=size)
+        latent_m_val = draw(latent_m, draws=size)
+        assert m_val.shape == latent_m_val.shape
+        # Test that each element in axis = -1 comes from the same mixture
+        # component
+        assert np.all(np.diff(m_val) < 1e-3)
+        assert np.all(np.diff(latent_m_val) < 1e-3)
+        # TODO: The following statistical test appears to be more flaky than expected
+        # even though the distributions should be the same. Seeding should make it
+        # stable but might be worth investigating further
         self.samples_from_same_distribution(m_val, latent_m_val)
+
+        # Check that mixing of values in the last axis leads to smaller logp
+        logp_fn = model.compile_logp(vars=[m])
+        assert logp_fn({"m": [0, 0, 0]}) > logp_fn({"m": [0, 1, 0]}) > logp_fn({"m": [0, 1, 2]})
         self.logp_matches(m, latent_m, z, npop, model=model)

     def samples_from_same_distribution(self, *args):
@@ -848,31 +847,42 @@ def samples_from_same_distribution(self, *args):
         _, p_correlation = st.ks_2samp(
             *(np.array([np.corrcoef(ss) for ss in s]).flatten() for s in args)
         )
+        # Even when the distributions are the same, this assertion only passes ~90% of the time (0.95 ** 2)
         assert p_marginal >= 0.05 and p_correlation >= 0.05

     def logp_matches(self, mixture, latent_mix, z, npop, model):
+        def loose_logp(model, vars):
+            """Return logp function that accepts dictionary with unused variables as input"""
+            return model.compile_fn(
+                model.logpt(vars=vars, sum=False),
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+            )
+
         if aesara.config.floatX == "float32":
             rtol = 1e-4
         else:
             rtol = 1e-7
         test_point = model.compute_initial_point()
-        test_point["latent_m"] = test_point["m"]
-        mix_logp = mixture.logp(test_point)
-        logps = []
+        test_point["m"] = test_point["latent_m"]
+
+        mix_logp = loose_logp(model, mixture)(test_point)[0]
+
+        z_shape = z.shape.eval()
+        latent_mix_components_logps = []
         for component in range(npop):
-            test_point["z"] = component * np.ones(z.distribution.shape)
-            # Count the number of axes that should be broadcasted from z to
-            # modify the logp
-            sh1 = test_point["z"].shape
-            sh2 = test_point["latent_m"].shape
-            if len(sh1) > len(sh2):
-                sh2 = (1,) * (len(sh1) - len(sh2)) + sh2
-            elif len(sh2) > len(sh1):
-                sh1 = (1,) * (len(sh2) - len(sh1)) + sh1
-            reps = np.prod([s2 if s1 != s2 else 1 for s1, s2 in zip(sh1, sh2)])
-            z_logp = z.logp(test_point) * reps
-            logps.append(z_logp + latent_mix.logp(test_point))
-        latent_mix_logp = logsumexp(np.array(logps), axis=0)
+            test_point["z"] = np.full(z_shape, component)
+            z_logp = loose_logp(model, z)(test_point)[0]
+            latent_mix_component_logp = loose_logp(model, latent_mix)(test_point)[0]
+            # If the mixture ndim_supp is a vector, the logp should be summed within
+            # components, as its items are not independent
+            if mix_logp.ndim == 0:
+                latent_mix_component_logp = latent_mix_component_logp.sum()
+            latent_mix_components_logps.append(z_logp + latent_mix_component_logp)
+        latent_mix_logp = logsumexp(np.array(latent_mix_components_logps), axis=0)
+        if mix_logp.ndim == 0:
+            latent_mix_logp = latent_mix_logp.sum()
+
         assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
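Aside (not part of the commit): the comment in test_vector_components states that an MvNormal with diagonal covariance sigma**2 * I should match a vector of independent Normals with scale sigma. A minimal scipy sketch of that equivalence, with illustrative values (mu, x, sigma) chosen here rather than taken from the test:

    import numpy as np
    from scipy.stats import multivariate_normal, norm

    nd, sigma = 3, 1e-5
    mu = np.full(nd, 2.0)  # one component mean, repeated across dimensions
    x = mu + 1e-6          # a point close to that mean

    # MvNormal with diagonal covariance sigma**2 * I ...
    mv_logp = multivariate_normal(mean=mu, cov=np.eye(nd) * sigma**2).logpdf(x)
    # ... has the same logp as nd independent Normals with scale sigma
    iid_logp = norm(loc=mu, scale=sigma).logpdf(x).sum()

    assert np.isclose(mv_logp, iid_logp)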
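Similarly, the "mixing of values in the last axis leads to smaller logp" ordering that test_vector_components asserts via compile_logp can be checked in isolation. This is a sketch using scipy rather than the compiled model logp; mixture_logp is a hypothetical helper, and the weights and means mirror the test's uniform w and np.arange(npop) means:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import multivariate_normal

    nd, npop, sigma = 3, 4, 1e-5
    mus = np.full((nd, npop), np.arange(npop))  # same component means as in the tests
    w = np.full(npop, 1 / npop)                 # uniform mixture weights

    def mixture_logp(x):
        # log p(x) = logsumexp_k [log w_k + log MvN(x | mus[:, k], sigma**2 * I)]
        comp_logps = [
            multivariate_normal(mean=mus[:, k], cov=np.eye(nd) * sigma**2).logpdf(x)
            for k in range(npop)
        ]
        return logsumexp(np.log(w) + np.array(comp_logps))

    # A value drawn entirely from one component is far more likely than one whose
    # entries mix components, which is what the compile_logp ordering check exercises.
    assert mixture_logp([0, 0, 0]) > mixture_logp([0, 1, 0]) > mixture_logp([0, 1, 2])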