Skip to content

Commit 620b11d

Browse files
Refactor Mixture distribution for V4 (#5438)
* Set `expand=True` when calling `change_size` in `SymbolicDistribution` * Move NormalMixture tests to their own class * Move mixture random tests from test_distributions_random to test_mixture * Use specific imports in test_mixture * Reenable Mixture tests in pytest workflow * Refactor Mixture distribution Mixtures now use an `OpFromGraph` that encapsulates the Aesara random method. This is used so that logp can be easily dispatched to the distribution without requiring involved pattern matching. The Mixture random and logp methods now fully respect the support dimensionality of its components, whereas previously only the logp method did, leading to inconsistencies between the two methods. In the case where the weights (or size) indicate the need for more draws than what is given by the component distributions, the latter are resized to ensure there are no repeated draws. This refactoring forces Mixture components to be basic RandomVariables, meaning that nested Mixtures or Mixtures of Symbolic distributions (like Censored) are not currently possible. Co-authored-by: Larry Dong <[email protected]> * Add warning when using iterable with single Mixture component * Update Mixture docstrings * Emphasize equivalency between iterable of components and single batched component * Add example with mixture of two distinct distributions * Add example with multivariate components * Refactor NormalMixture * Refactor TestMixtureVsLatent The two tests relied on implicit behavior of V3, where the dimensionality of the weights implied the support dimension of mixture distribution. This, however, led to inconsistent behavior between the random method and the logp, as the latter did not enforce this assumption, and did not distinguish if values were mixed across the implied support dimension. In this refactoring, the support dimensionality of the component variables determines the dimensionality of the mixture distribution, regardless of the weights. This leads to consistent behavior between the random and logp methods as asserted by the new checks. Future work will explore allowing the user to specify an artificial support dimensionality that is higher than the one implied by the component distributions, but this is for now not possible. * Remove MixtureSameFamily Behavior is now implemented in Mixture * Add Mixture moments * Update release notes Co-authored-by: Larry Dong <[email protected]>
1 parent 7d4162c commit 620b11d

File tree

9 files changed

+1241
-1138
lines changed

9 files changed

+1241
-1138
lines changed

.github/workflows/pytest.yml

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ jobs:
7878
pymc/tests/test_transforms.py
7979
pymc/tests/test_smc.py
8080
pymc/tests/test_bart.py
81+
pymc/tests/test_mixture.py
8182
8283
- |
8384
pymc/tests/test_parallel_sampling.py

RELEASE-NOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Instead update the vNext section until 4.0.0 is out.
1313
### Not-yet working features
1414
We plan to get these working again, but at this point their inner workings have not been refactored.
1515
- Timeseries distributions (see [#4642](https://github.com/pymc-devs/pymc/issues/4642))
16-
- Mixture distributions (see [#4781](https://github.com/pymc-devs/pymc/issues/4781))
16+
- Nested Mixture distributions (see [#5533](https://github.com/pymc-devs/pymc/issues/5533))
1717
- Elliptical slice sampling (see [#5137](https://github.com/pymc-devs/pymc/issues/5137))
1818
- `BaseStochasticGradient` (see [#5138](https://github.com/pymc-devs/pymc/issues/5138))
1919
- `pm.sample_posterior_predictive_w` (see [#4807](https://github.com/pymc-devs/pymc/issues/4807))
@@ -72,6 +72,7 @@ All of the above apply to:
7272
- In the gp.utils file, the `kmeans_inducing_points` function now passes through `kmeans_kwargs` to scipy's k-means function.
7373
- The function `replace_with_values` function has been added to `gp.utils`.
7474
- `MarginalSparse` has been renamed `MarginalApprox`.
75+
- Removed `MixtureSameFamily`. `Mixture` is now capable of handling batched multivariate components (see [#5438](https://github.com/pymc-devs/pymc/pull/5438)).
7576
- ...
7677

7778
### Expected breaks

docs/source/api/distributions/mixture.rst

-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ Mixture
88

99
Mixture
1010
NormalMixture
11-
MixtureSameFamily

pymc/distributions/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
NoDistribution,
8484
SymbolicDistribution,
8585
)
86-
from pymc.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
86+
from pymc.distributions.mixture import Mixture, NormalMixture
8787
from pymc.distributions.multivariate import (
8888
CAR,
8989
Dirichlet,
@@ -180,7 +180,6 @@
180180
"SkewNormal",
181181
"Mixture",
182182
"NormalMixture",
183-
"MixtureSameFamily",
184183
"Triangular",
185184
"DiscreteWeibull",
186185
"Gumbel",

pymc/distributions/distribution.py

+1
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def __new__(
500500
rv_out = cls.change_size(
501501
rv=rv_out,
502502
new_size=resize_shape,
503+
expand=True,
503504
)
504505

505506
rv_out = model.register_rv(

0 commit comments

Comments
 (0)