Commit 34f4679 (parent: 7406beb)

Remove MixtureSameFamily

Behavior is now implemented in Mixture.
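
For context: the use case `MixtureSameFamily` served, mixing over a dedicated component axis of a single vectorized distribution, is now expected to work through `Mixture` directly, as the updated tests below exercise. A minimal sketch of the replacement usage, assuming a PyMC v4 environment; the weights, shapes, and parameter values are illustrative, not taken from the commit:

    import numpy as np
    import pymc as pm

    with pm.Model():
        # Five 3-dimensional MvNormal components stacked along the first axis.
        # Mixture consumes that component axis, so each draw has shape (3,).
        mu = np.random.randn(5, 3)
        comp_dists = pm.MvNormal.dist(mu=mu, cov=np.eye(3), shape=(5, 3))
        w = np.ones(5) / 5  # equal mixture weights, must sum to 1
        mix = pm.Mixture("mix", w=w, comp_dists=comp_dists, shape=(3,))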

File tree: 5 files changed (+22 lines, -262 lines)


docs/source/api/distributions/mixture.rst (0 additions, 1 deletion)

@@ -8,4 +8,3 @@ Mixture
 
    Mixture
    NormalMixture
-   MixtureSameFamily

pymc/distributions/__init__.py (1 addition, 2 deletions)

@@ -83,7 +83,7 @@
     NoDistribution,
     SymbolicDistribution,
 )
-from pymc.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
+from pymc.distributions.mixture import Mixture, NormalMixture
 from pymc.distributions.multivariate import (
     CAR,
     Dirichlet,
@@ -180,7 +180,6 @@
     "SkewNormal",
     "Mixture",
     "NormalMixture",
-    "MixtureSameFamily",
     "Triangular",
     "DiscreteWeibull",
     "Gumbel",

pymc/distributions/mixture.py (2 additions, 235 deletions)

@@ -23,16 +23,15 @@
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
-from pymc.aesaraf import change_rv_size, take_along_axis
+from pymc.aesaraf import change_rv_size
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Discrete, Distribution, SymbolicDistribution
 from pymc.distributions.logprob import logp
 from pymc.distributions.shape_utils import to_tuple
-from pymc.math import logsumexp
 from pymc.util import check_dist_not_registered
 
-__all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"]
+__all__ = ["Mixture", "NormalMixture"]
 
 
 def all_discrete(comp_dists):
@@ -468,235 +467,3 @@ def dist(cls, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
         return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
-
-
-class MixtureSameFamily(Distribution):
-    R"""
-    Mixture Same Family log-likelihood
-    This distribution handles mixtures of multivariate distributions in a vectorized
-    manner. It is used over the Mixture distribution when the mixture components are
-    not present on the last axis of the components' distribution.
-
-    .. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i) \textrm{ along mixture\_axis}
-
-    ========  ============================================
-    Support   :math:`\textrm{support}(f)`
-    Mean      :math:`w\mu`
-    ========  ============================================
-
-    Parameters
-    ----------
-    w: array of floats
-        w >= 0 and w <= 1
-        the mixture weights
-    comp_dists: PyMC distribution (e.g. `pm.Multinomial.dist(...)`)
-        The `comp_dists` can be a scalar or multidimensional distribution.
-        Assuming its shape to be (i_0, ..., i_n, mixture_axis, i_n+1, ..., i_N),
-        the `mixture_axis` is consumed, resulting in a mixture of shape
-        (i_0, ..., i_n, i_n+1, ..., i_N).
-    mixture_axis: int, default = -1
-        Axis representing the mixture components to be reduced in the mixture.
-
-    Notes
-    -----
-    The default behaviour resembles the Mixture distribution, wherein the last axis
-    of the component distribution is reduced.
-    """
-
-    def __init__(self, w, comp_dists, mixture_axis=-1, *args, **kwargs):
-        self.w = at.as_tensor_variable(w)
-        if not isinstance(comp_dists, Distribution):
-            raise TypeError(
-                "The MixtureSameFamily distribution only accepts Distribution "
-                f"instances as its components. Got {type(comp_dists)} instead."
-            )
-        self.comp_dists = comp_dists
-        if mixture_axis < 0:
-            mixture_axis = len(comp_dists.shape) + mixture_axis
-            if mixture_axis < 0:
-                raise ValueError(
-                    "`mixture_axis` is supposed to be in shape of components' distribution. "
-                    f"Got {mixture_axis + len(comp_dists.shape)} axis instead out of the bounds."
-                )
-        comp_shape = to_tuple(comp_dists.shape)
-        self.shape = comp_shape[:mixture_axis] + comp_shape[mixture_axis + 1 :]
-        self.mixture_axis = mixture_axis
-        kwargs.setdefault("dtype", self.comp_dists.dtype)
-
-        # Compute the mode so we don't always have to pass an initval
-        defaults = kwargs.pop("defaults", [])
-        event_shape = self.comp_dists.shape[mixture_axis + 1 :]
-        _w = at.shape_padleft(
-            at.shape_padright(w, len(event_shape)),
-            len(self.comp_dists.shape) - w.ndim - len(event_shape),
-        )
-        mode = take_along_axis(
-            self.comp_dists.mode,
-            at.argmax(_w, keepdims=True),
-            axis=mixture_axis,
-        )
-        self.mode = mode[(..., 0) + (slice(None),) * len(event_shape)]
-
-        if not all_discrete(comp_dists):
-            mean = at.as_tensor_variable(self.comp_dists.mean)
-            self.mean = (_w * mean).sum(axis=mixture_axis)
-            if "mean" not in defaults:
-                defaults.append("mean")
-        defaults.append("mode")
-
-        super().__init__(defaults=defaults, *args, **kwargs)
-
-    def logp(self, value):
-        """
-        Calculate log-probability of defined ``MixtureSameFamily`` distribution at specified value.
-
-        Parameters
-        ----------
-        value : numeric
-            Value(s) for which log-probability is calculated. If the log probabilities for multiple
-            values are desired, the values must be provided in a numpy array or Aesara tensor.
-
-        Returns
-        -------
-        TensorVariable
-        """
-
-        comp_dists = self.comp_dists
-        w = self.w
-        mixture_axis = self.mixture_axis
-
-        event_shape = comp_dists.shape[mixture_axis + 1 :]
-
-        # To be able to broadcast the comp_dists.logp with w and value,
-        # we first have to pad the shape of w to the right with ones
-        # so that it can broadcast with the event_shape.
-
-        w = at.shape_padright(w, len(event_shape))
-
-        # Second, we have to add the mixture_axis to the value tensor.
-        # To insert the mixture axis at the correct location, we use the
-        # negative number index. This way, we can also handle situations
-        # in which value is an observed value with more batch dimensions
-        # than the ones present in the comp_dists.
-        comp_dists_ndim = len(comp_dists.shape)
-
-        value = at.shape_padaxis(value, axis=mixture_axis - comp_dists_ndim)
-
-        comp_logp = comp_dists.logp(value)
-        return check_parameters(
-            logsumexp(at.log(w) + comp_logp, axis=mixture_axis, keepdims=False),
-            w >= 0,
-            w <= 1,
-            at.allclose(w.sum(axis=mixture_axis - comp_dists_ndim), 1),
-            broadcast_conditions=False,
-        )
-
-    def random(self, point=None, size=None):
-        """
-        Draw random values from defined ``MixtureSameFamily`` distribution.
-
-        Parameters
-        ----------
-        point : dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size : int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # sample_shape = to_tuple(size)
-        # mixture_axis = self.mixture_axis
-        #
-        # # First we draw values for the mixture component weights
-        # (w,) = draw_values([self.w], point=point, size=size)
-        #
-        # # We now draw random choices from those weights.
-        # # However, we have to ensure that the number of choices has the
-        # # sample_shape present.
-        # w_shape = w.shape
-        # batch_shape = self.comp_dists.shape[: mixture_axis + 1]
-        # param_shape = np.broadcast(np.empty(w_shape), np.empty(batch_shape)).shape
-        # event_shape = self.comp_dists.shape[mixture_axis + 1 :]
-        #
-        # if np.asarray(self.shape).size != 0:
-        #     comp_dists_ndim = len(self.comp_dists.shape)
-        #
-        #     # If event_shape of both comp_dists and supplied shape matches,
-        #     # broadcast only batch_shape,
-        #     # else broadcast the entire given shape with batch_shape.
-        #     if list(self.shape[mixture_axis - comp_dists_ndim + 1 :]) == list(event_shape):
-        #         dist_shape = np.broadcast(
-        #             np.empty(self.shape[:mixture_axis]), np.empty(param_shape[:mixture_axis])
-        #         ).shape
-        #     else:
-        #         dist_shape = np.broadcast(
-        #             np.empty(self.shape), np.empty(param_shape[:mixture_axis])
-        #         ).shape
-        # else:
-        #     dist_shape = param_shape[:mixture_axis]
-        #
-        # # Try to determine the size that must be used to get the mixture
-        # # components (i.e. get random choices using w).
-        # # 1. There must be size independent choices based on w.
-        # # 2. There must also be independent draws for each non singleton axis
-        # # of w.
-        # # 3. There must also be independent draws for each dimension added by
-        # # self.shape with respect to the w.ndim. These usually correspond to
-        # # observed variables with batch shapes.
-        # wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:mixture_axis]
-        # psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:mixture_axis]
-        # w_sample_size = []
-        # # Loop through the dist_shape to get the conditions 2 and 3 first
-        # for i in range(len(dist_shape)):
-        #     if dist_shape[i] != psh[i] and wsh[i] == 1:
-        #         # self.shape[i] is a non singleton dimension (usually caused by
-        #         # observed data)
-        #         sh = dist_shape[i]
-        #     else:
-        #         sh = wsh[i]
-        #     w_sample_size.append(sh)
-        #
-        # if sample_shape is not None and w_sample_size[: len(sample_shape)] != sample_shape:
-        #     w_sample_size = sample_shape + tuple(w_sample_size)
-        #
-        # choices = random_choice(p=w, size=w_sample_size)
-        #
-        # # We now draw samples from the mixture components' random method
-        # comp_samples = self.comp_dists.random(point=point, size=size)
-        # if comp_samples.shape[: len(sample_shape)] != sample_shape:
-        #     comp_samples = np.broadcast_to(
-        #         comp_samples,
-        #         shape=sample_shape + comp_samples.shape,
-        #     )
-        #
-        # # At this point the shapes of the arrays involved are:
-        # # comp_samples.shape = (sample_shape, batch_shape, mixture_axis, event_shape)
-        # # choices.shape = (sample_shape, batch_shape)
-        # #
-        # # To be able to take the choices along the mixture_axis of the
-        # # comp_samples, we have to add in dimensions to the right of the
-        # # choices array.
-        # # We also need to make sure that the batch_shapes of both the comp_samples
-        # # and choices broadcast with each other.
-        #
-        # choices = np.reshape(choices, choices.shape + (1,) * (1 + len(event_shape)))
-        #
-        # choices, comp_samples = get_broadcastable_dist_samples([choices, comp_samples], size=size)
-        #
-        # # We now take the choices of the mixture components along the mixture_axis,
-        # # but we use the negative index representation to be able to handle the
-        # # sample_shape
-        # samples = np.take_along_axis(
-        #     comp_samples, choices, axis=mixture_axis - len(self.comp_dists.shape)
-        # )
-        #
-        # # The `samples` array still has the `mixture_axis`, so we must remove it:
-        # output = samples[(..., 0) + (slice(None),) * len(event_shape)]
-        # return output
-
-    def _distr_parameters_for_repr(self):
-        return []
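
The deleted `logp` above implements the standard mixture marginalization from the class docstring, f(x | w, theta) = sum_i w_i f_i(x | theta_i), evaluated in log space as logsumexp(log w_i + log f_i(x)) for numerical stability. A small self-contained NumPy/SciPy sketch of the same computation; the weights and component parameters are made up for illustration:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    # Hypothetical setup: three Gaussian components evaluated at one point.
    w = np.array([0.2, 0.5, 0.3])  # mixture weights, sum to 1
    mus = np.array([-1.0, 0.0, 2.0])
    sigmas = np.array([1.0, 0.5, 1.5])
    x = 0.3

    comp_logp = norm.logpdf(x, loc=mus, scale=sigmas)  # log f_i(x) per component
    # log f(x) = logsumexp_i(log w_i + log f_i(x)); working in log space avoids
    # the underflow that exponentiating and summing densities directly can cause.
    mix_logp = logsumexp(np.log(w) + comp_logp)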

pymc/tests/test_distributions_moments.py (0 additions, 1 deletion)

@@ -100,7 +100,6 @@ def test_all_distributions_have_moments():
 
     # Distributions that have not been refactored for V4 yet
     not_implemented = {
-        dist_module.mixture.MixtureSameFamily,
         dist_module.timeseries.AR,
         dist_module.timeseries.AR1,
         dist_module.timeseries.GARCH11,

pymc/tests/test_mixture.py (19 additions, 23 deletions)

@@ -34,7 +34,6 @@
     LKJCholeskyCov,
     LogNormal,
     Mixture,
-    MixtureSameFamily,
     Multinomial,
     MvNormal,
     Normal,
@@ -886,8 +885,12 @@ def loose_logp(model, vars):
     assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
 
 
-@pytest.mark.xfail(reason="MixtureSameFamily not refactored yet")
 class TestMixtureSameFamily(SeededTest):
+    """Tests that used to belong to the deprecated `MixtureSameFamily` distribution.
+
+    The functionality is now expected to be provided by `Mixture`.
+    """
+
     @classmethod
     def setup_class(cls):
         super().setup_class()
@@ -903,36 +906,33 @@ def test_with_multinomial(self, batch_shape):
         mixture_axis = len(batch_shape)
         with Model() as model:
             comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
-            mixture = MixtureSameFamily(
+            mixture = Mixture(
                 "mixture",
                 w=w,
                 comp_dists=comp_dists,
-                mixture_axis=mixture_axis,
                 shape=(*batch_shape, 3),
             )
-            prior = sample_prior_predictive(samples=self.n_samples)
+            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
 
         assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
-        assert mixture.random(size=self.size).shape == (self.size, *batch_shape, 3)
+        assert draw(mixture, draws=self.size).shape == (self.size, *batch_shape, 3)
 
         if aesara.config.floatX == "float32":
             rtol = 1e-4
         else:
             rtol = 1e-7
 
         initial_point = model.compute_initial_point()
-        comp_logp = comp_dists.logp(initial_point["mixture"].reshape(*batch_shape, 1, 3))
+        comp_logp = logp(comp_dists, initial_point["mixture"].reshape(*batch_shape, 1, 3))
         log_sum_exp = logsumexp(
-            comp_logp.eval() + np.log(w)[..., None], axis=mixture_axis, keepdims=True
+            comp_logp.eval() + np.log(w), axis=mixture_axis, keepdims=True
         ).sum()
         assert_allclose(
-            model.logp(initial_point),
+            model.compile_logp()(initial_point),
             log_sum_exp,
             rtol,
         )
 
-    # TODO: Handle case when `batch_shape` == `sample_shape`.
-    # See https://github.com/pymc-devs/pymc/issues/4185 for details.
     def test_with_mvnormal(self):
         # 10 batch, 3-variate Gaussian
         mu = np.random.randn(self.mixture_comps, 3)
@@ -943,26 +943,22 @@ def test_with_mvnormal(self):
 
         with Model() as model:
             comp_dists = MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
-            mixture = MixtureSameFamily(
-                "mixture", w=w, comp_dists=comp_dists, mixture_axis=0, shape=(3,)
-            )
-            prior = sample_prior_predictive(samples=self.n_samples)
+            mixture = Mixture("mixture", w=w, comp_dists=comp_dists, shape=(3,))
+            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
 
         assert prior["mixture"].shape == (self.n_samples, 3)
-        assert mixture.random(size=self.size).shape == (self.size, 3)
+        assert draw(mixture, draws=self.size).shape == (self.size, 3)
 
         if aesara.config.floatX == "float32":
             rtol = 1e-4
         else:
             rtol = 1e-7
 
         initial_point = model.compute_initial_point()
-        comp_logp = comp_dists.logp(initial_point["mixture"].reshape(1, 3))
-        log_sum_exp = logsumexp(
-            comp_logp.eval() + np.log(w)[..., None], axis=0, keepdims=True
-        ).sum()
+        comp_logp = logp(comp_dists, initial_point["mixture"].reshape(1, 3))
+        log_sum_exp = logsumexp(comp_logp.eval() + np.log(w), axis=0, keepdims=True).sum()
         assert_allclose(
-            model.logp(initial_point),
+            model.compile_logp()(initial_point),
             log_sum_exp,
             rtol,
         )
@@ -971,7 +967,7 @@ def test_broadcasting_in_shape(self):
         with Model() as model:
             mu = Gamma("mu", 1.0, 1.0, shape=2)
             comp_dists = Poisson.dist(mu, shape=2)
-            mix = MixtureSameFamily("mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,))
-            prior = sample_prior_predictive(samples=self.n_samples)
+            mix = Mixture("mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,))
+            prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
 
         assert prior["mix"].shape == (self.n_samples, 1000)
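
Beyond the class swap, the test updates above track broader v4 API changes: forward sampling moves from `dist.random(size=n)` to `pm.draw(dist, draws=n)`, the model log-density moves from `model.logp(point)` to the compiled `model.compile_logp()(point)`, and `sample_prior_predictive` needs `return_inferencedata=False` to keep returning a plain dict. A minimal sketch of the new calls, assuming a PyMC v4 environment; the model itself is illustrative:

    import pymc as pm

    with pm.Model() as model:
        x = pm.Normal("x", mu=0.0, sigma=1.0)

    # pm.draw replaces the old Distribution.random() method.
    samples = pm.draw(x, draws=100)
    assert samples.shape == (100,)

    # Compile the model log-density once, then evaluate it at a point dict.
    point = model.compute_initial_point()
    logp_fn = model.compile_logp()
    print(logp_fn(point))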
