
Commit 7406beb

Refactor TestMixtureVsLatent
The two tests relied on implicit behavior of V3, where the dimensionality of the weights implied the support dimension of the mixture distribution. This, however, led to inconsistent behavior between the random and logp methods, as the latter did not enforce this assumption and did not distinguish whether values were mixed across the implied support dimension. In this refactoring, the support dimensionality of the component variables determines the dimensionality of the mixture distribution, regardless of the weights. This leads to consistent behavior between the random and logp methods, as asserted by the new checks. Future work will explore allowing the user to specify an artificial support dimensionality higher than the one implied by the component distributions, but this is not currently possible.
1 parent 524025b commit 7406beb
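
To make the described behavior concrete, here is a minimal sketch (not part of the commit) contrasting a mixture of scalar components with a mixture of vector components. It reuses the same constructors and the `draw` helper that appear in the diff below; the `import pymc as pm` spelling, the variable names `scalar_mix`/`vector_mix`, and the exact import paths are assumptions, since the test module imports these names directly and signatures may differ between versions.

# Hedged sketch of the behavior described above; not part of the commit.
# Assumes `import pymc as pm` exposes NormalMixture, Mixture, MvNormal and draw
# with the signatures used in the diff below.
import numpy as np
import pymc as pm

nd, npop = 3, 4
mus = np.full((nd, npop), np.arange(npop))  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

with pm.Model():
    # Scalar components: the mixture has ndim_supp == 0, so each of the `nd`
    # entries of a draw can pick its own component.
    scalar_mix = pm.NormalMixture(
        "scalar_mix",
        w=np.ones(npop) / npop,
        mu=mus,
        sigma=1e-5,
        comp_shape=(nd, npop),
        shape=nd,
    )
    # Vector (MvNormal) components: the mixture has ndim_supp == 1, so all `nd`
    # entries of a draw come from the same component, regardless of how the
    # weights are shaped.
    vector_mix = pm.Mixture(
        "vector_mix",
        w=np.ones(npop) / npop,
        comp_dists=[pm.MvNormal.dist(mus[:, i], np.eye(nd) * 1e-5**2) for i in range(npop)],
    )

scalar_draws = pm.draw(scalar_mix, draws=100)  # entries along the last axis can differ
vector_draws = pm.draw(vector_mix, draws=100)  # each draw is (nearly) constant along the last axis

Under the refactor, logp evaluations are consistent with these draws, which is what the new `compile_logp` checks in the diff assert.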

File tree

1 file changed (+77, -67 lines)


pymc/tests/test_mixture.py

Lines changed: 77 additions & 67 deletions
@@ -767,60 +767,19 @@ def ref_rand(size, w, mu, sigma):
         )
 
 
-@pytest.mark.xfail(reason="NormalMixture not refactored yet")
 class TestMixtureVsLatent(SeededTest):
-    def setup_method(self, *args, **kwargs):
-        super().setup_method(*args, **kwargs)
-        self.nd = 3
-        self.npop = 3
-        self.mus = at.as_tensor_variable(
-            np.tile(
-                np.reshape(
-                    np.arange(self.npop),
-                    (
-                        1,
-                        -1,
-                    ),
-                ),
-                (
-                    self.nd,
-                    1,
-                ),
-            )
-        )
+    """This class contains tests that compare a marginal Mixture with a latent indexed Mixture"""
 
-    def test_1d_w(self):
-        nd = self.nd
-        npop = self.npop
-        mus = self.mus
-        size = 100
-        with Model() as model:
-            m = NormalMixture(
-                "m", w=np.ones(npop) / npop, mu=mus, sigma=1e-5, comp_shape=(nd, npop), shape=nd
-            )
-            z = Categorical("z", p=np.ones(npop) / npop)
-            latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd)
+    def test_scalar_components(self):
+        nd = 3
+        npop = 4
+        # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+        mus = at.constant(np.full((nd, npop), np.arange(npop)))
 
-        m_val = m.random(size=size)
-        latent_m_val = latent_m.random(size=size)
-        assert m_val.shape == latent_m_val.shape
-        # Test that each element in axis = -1 comes from the same mixture
-        # component
-        assert all(np.all(np.diff(m_val) < 1e-3, axis=-1))
-        assert all(np.all(np.diff(latent_m_val) < 1e-3, axis=-1))
-
-        self.samples_from_same_distribution(m_val, latent_m_val)
-        self.logp_matches(m, latent_m, z, npop, model=model)
-
-    def test_2d_w(self):
-        nd = self.nd
-        npop = self.npop
-        mus = self.mus
-        size = 100
-        with Model() as model:
+        with Model(rng_seeder=self.get_random_state()) as model:
             m = NormalMixture(
                 "m",
-                w=np.ones((nd, npop)) / npop,
+                w=np.ones(npop) / npop,
                 mu=mus,
                 sigma=1e-5,
                 comp_shape=(nd, npop),
@@ -830,15 +789,55 @@ def test_2d_w(self):
             mu = at.as_tensor_variable([mus[i, z[i]] for i in range(nd)])
             latent_m = Normal("latent_m", mu=mu, sigma=1e-5, shape=nd)
 
-        m_val = m.random(size=size)
-        latent_m_val = latent_m.random(size=size)
+        size = 100
+        m_val = draw(m, draws=size)
+        latent_m_val = draw(latent_m, draws=size)
+
         assert m_val.shape == latent_m_val.shape
         # Test that each element in axis = -1 can come from independent
         # components
         assert not all(np.all(np.diff(m_val) < 1e-3, axis=-1))
         assert not all(np.all(np.diff(latent_m_val) < 1e-3, axis=-1))
+        self.samples_from_same_distribution(m_val, latent_m_val)
+
+        # Check that logp is the same whether elements of the last axis are mixed or not
+        logp_fn = model.compile_logp(vars=[m])
+        assert np.isclose(logp_fn({"m": [0, 0, 0]}), logp_fn({"m": [0, 1, 2]}))
+        self.logp_matches(m, latent_m, z, npop, model=model)
+
+    def test_vector_components(self):
+        nd = 3
+        npop = 4
+        # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
+        mus = at.constant(np.full((nd, npop), np.arange(npop)))
+
+        with Model(rng_seeder=self.get_random_state()) as model:
+            m = Mixture(
+                "m",
+                w=np.ones(npop) / npop,
+                # MvNormal distribution with squared sigma diagonal covariance should
+                # be equal to vector of Normals from latent_m
+                comp_dists=[MvNormal.dist(mus[:, i], np.eye(nd) * 1e-5**2) for i in range(npop)],
+            )
+            z = Categorical("z", p=np.ones(npop) / npop)
+            latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd)
 
+        size = 100
+        m_val = draw(m, draws=size)
+        latent_m_val = draw(latent_m, draws=size)
+        assert m_val.shape == latent_m_val.shape
+        # Test that each element in axis = -1 comes from the same mixture
+        # component
+        assert np.all(np.diff(m_val) < 1e-3)
+        assert np.all(np.diff(latent_m_val) < 1e-3)
+        # TODO: The following statistical test appears to be more flaky than expected
+        # even though the distributions should be the same. Seeding should make it
+        # stable but might be worth investigating further
         self.samples_from_same_distribution(m_val, latent_m_val)
+
+        # Check that mixing of values in the last axis leads to smaller logp
+        logp_fn = model.compile_logp(vars=[m])
+        assert logp_fn({"m": [0, 0, 0]}) > logp_fn({"m": [0, 1, 0]}) > logp_fn({"m": [0, 1, 2]})
         self.logp_matches(m, latent_m, z, npop, model=model)
 
     def samples_from_same_distribution(self, *args):
@@ -848,31 +847,42 @@ def samples_from_same_distribution(self, *args):
         _, p_correlation = st.ks_2samp(
             *(np.array([np.corrcoef(ss) for ss in s]).flatten() for s in args)
         )
+        # This fails about 10% of the time (1 - 0.95**2), even if the distributions are the same
         assert p_marginal >= 0.05 and p_correlation >= 0.05
 
     def logp_matches(self, mixture, latent_mix, z, npop, model):
+        def loose_logp(model, vars):
+            """Return logp function that accepts dictionary with unused variables as input"""
+            return model.compile_fn(
+                model.logpt(vars=vars, sum=False),
+                inputs=model.value_vars,
+                on_unused_input="ignore",
+            )
+
         if aesara.config.floatX == "float32":
             rtol = 1e-4
         else:
             rtol = 1e-7
         test_point = model.compute_initial_point()
-        test_point["latent_m"] = test_point["m"]
-        mix_logp = mixture.logp(test_point)
-        logps = []
+        test_point["m"] = test_point["latent_m"]
+
+        mix_logp = loose_logp(model, mixture)(test_point)[0]
+
+        z_shape = z.shape.eval()
+        latent_mix_components_logps = []
         for component in range(npop):
-            test_point["z"] = component * np.ones(z.distribution.shape)
-            # Count the number of axes that should be broadcasted from z to
-            # modify the logp
-            sh1 = test_point["z"].shape
-            sh2 = test_point["latent_m"].shape
-            if len(sh1) > len(sh2):
-                sh2 = (1,) * (len(sh1) - len(sh2)) + sh2
-            elif len(sh2) > len(sh1):
-                sh1 = (1,) * (len(sh2) - len(sh1)) + sh1
-            reps = np.prod([s2 if s1 != s2 else 1 for s1, s2 in zip(sh1, sh2)])
-            z_logp = z.logp(test_point) * reps
-            logps.append(z_logp + latent_mix.logp(test_point))
-        latent_mix_logp = logsumexp(np.array(logps), axis=0)
+            test_point["z"] = np.full(z_shape, component)
+            z_logp = loose_logp(model, z)(test_point)[0]
+            latent_mix_component_logp = loose_logp(model, latent_mix)(test_point)[0]
+            # If the mixture ndim_supp is a vector, the logp should be summed within
+            # components, as its items are not independent
+            if mix_logp.ndim == 0:
+                latent_mix_component_logp = latent_mix_component_logp.sum()
+            latent_mix_components_logps.append(z_logp + latent_mix_component_logp)
+        latent_mix_logp = logsumexp(np.array(latent_mix_components_logps), axis=0)
+        if mix_logp.ndim == 0:
+            latent_mix_logp = latent_mix_logp.sum()
+
         assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
 
 
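For reference, the quantity that the refactored logp_matches helper reconstructs component by component is the usual mixture marginalization. A plain numpy/scipy sketch with hypothetical weights, means, and evaluation point (not taken from the commit) is:

# Sketch of the identity logp_matches asserts:
#   log p_mixture(value) == logsumexp_k( log w_k + log p_k(value) )
# The weights, means and evaluation point below are hypothetical.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

w = np.ones(4) / 4                 # mixture weights
mus = np.arange(4)                 # component means
value = 1.0                        # point at which to evaluate the log-density

component_logps = norm(loc=mus, scale=1e-5).logpdf(value)  # log p_k(value) per component
marginal_logp = logsumexp(np.log(w) + component_logps)     # marginal mixture log-density

# logp_matches builds the same quantity by looping over components, adding the
# Categorical logp of z == k to the latent Normal logp, and logsumexp-ing the results.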