Skip to content

Commit 04baabd

Browse files
larryshamalamaricardoV94
authored andcommitted
Add Mixture moments
1 parent 34f4679 commit 04baabd

File tree

2 files changed

+216
-1
lines changed

2 files changed

+216
-1
lines changed

pymc/distributions/mixture.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@
2626
from pymc.aesaraf import change_rv_size
2727
from pymc.distributions.continuous import Normal, get_tau_sigma
2828
from pymc.distributions.dist_math import check_parameters
29-
from pymc.distributions.distribution import Discrete, Distribution, SymbolicDistribution
29+
from pymc.distributions.distribution import (
30+
Discrete,
31+
Distribution,
32+
SymbolicDistribution,
33+
_get_moment,
34+
get_moment,
35+
)
3036
from pymc.distributions.logprob import logp
3137
from pymc.distributions.shape_utils import to_tuple
3238
from pymc.util import check_dist_not_registered
@@ -398,6 +404,24 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
398404
return mix_logp
399405

400406

407+
@_get_moment.register(MarginalMixtureRV)
408+
def get_moment_marginal_mixture(op, rv, rng, weights, *components):
409+
ndim_supp = components[0].owner.op.ndim_supp
410+
weights = at.shape_padright(weights, ndim_supp)
411+
mix_axis = -ndim_supp - 1
412+
413+
if len(components) == 1:
414+
moment_components = get_moment(components[0])
415+
416+
else:
417+
moment_components = at.stack(
418+
[get_moment(component) for component in components],
419+
axis=mix_axis,
420+
)
421+
422+
return at.sum(weights * moment_components, axis=mix_axis)
423+
424+
401425
class NormalMixture:
402426
R"""
403427
Normal mixture log-likelihood

pymc/tests/test_mixture.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pymc.step_methods import Metropolis
5454
from pymc.tests.helpers import SeededTest
5555
from pymc.tests.test_distributions import Domain, Simplex
56+
from pymc.tests.test_distributions_moments import assert_moment_is_expected
5657
from pymc.tests.test_distributions_random import pymc_random
5758

5859

@@ -971,3 +972,193 @@ def test_broadcasting_in_shape(self):
971972
prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
972973

973974
assert prior["mix"].shape == (self.n_samples, 1000)
975+
976+
977+
class TestMixtureMoments:
978+
@pytest.mark.parametrize(
979+
"weights, comp_dists, size, expected",
980+
[
981+
(
982+
np.array([0.4, 0.6]),
983+
Normal.dist(mu=np.array([-2, 6]), sigma=np.array([5, 3])),
984+
None,
985+
2.8,
986+
),
987+
(
988+
np.tile(1 / 13, 13),
989+
Normal.dist(-2, 1, size=(13,)),
990+
(3,),
991+
np.full((3,), -2),
992+
),
993+
(
994+
np.array([0.4, 0.6]),
995+
Normal.dist([-2, 6], 3),
996+
(5, 3),
997+
np.full((5, 3), 2.8),
998+
),
999+
(
1000+
np.broadcast_to(np.array([0.4, 0.6]), (5, 3, 2)),
1001+
Normal.dist(np.array([-2, 6]), np.array([5, 3])),
1002+
None,
1003+
np.full(shape=(5, 3), fill_value=2.8),
1004+
),
1005+
(
1006+
np.array([0.4, 0.6]),
1007+
Normal.dist(np.array([-2, 6]), np.array([5, 3]), size=(5, 3, 2)),
1008+
None,
1009+
np.full(shape=(5, 3), fill_value=2.8),
1010+
),
1011+
(
1012+
np.array([[0.8, 0.2], [0.2, 0.8]]),
1013+
Normal.dist(np.array([-2, 6])),
1014+
None,
1015+
np.array([-0.4, 4.4]),
1016+
),
1017+
# implied size = (11, 7) will be overwritten by (5, 3)
1018+
(
1019+
np.array([0.4, 0.6]),
1020+
Normal.dist(np.array([-2, 6]), np.array([5, 3]), size=(11, 7, 2)),
1021+
(5, 3),
1022+
np.full(shape=(5, 3), fill_value=2.8),
1023+
),
1024+
],
1025+
)
1026+
def test_single_univariate_component(self, weights, comp_dists, size, expected):
1027+
with Model() as model:
1028+
Mixture("x", weights, comp_dists, size=size)
1029+
assert_moment_is_expected(model, expected, check_finite_logp=False)
1030+
1031+
@pytest.mark.parametrize(
1032+
"weights, comp_dists, size, expected",
1033+
[
1034+
(
1035+
np.array([1, 0]),
1036+
[Normal.dist(-2, 5), Normal.dist(6, 3)],
1037+
None,
1038+
-2,
1039+
),
1040+
(
1041+
np.array([0.4, 0.6]),
1042+
[Normal.dist(-2, 5, size=(2,)), Normal.dist(6, 3, size=(2,))],
1043+
None,
1044+
np.full((2,), 2.8),
1045+
),
1046+
(
1047+
np.array([0.5, 0.5]),
1048+
[Normal.dist(-2, 5), Exponential.dist(lam=1 / 3)],
1049+
(3, 5),
1050+
np.full((3, 5), 0.5),
1051+
),
1052+
(
1053+
np.broadcast_to(np.array([0.4, 0.6]), (5, 3, 2)),
1054+
[Normal.dist(-2, 5), Normal.dist(6, 3)],
1055+
None,
1056+
np.full(shape=(5, 3), fill_value=2.8),
1057+
),
1058+
(
1059+
np.array([[0.8, 0.2], [0.2, 0.8]]),
1060+
[Normal.dist(-2, 5), Normal.dist(6, 3)],
1061+
None,
1062+
np.array([-0.4, 4.4]),
1063+
),
1064+
(
1065+
np.array([[0.8, 0.2], [0.2, 0.8]]),
1066+
[Normal.dist(-2, 5), Normal.dist(6, 3)],
1067+
(3, 2),
1068+
np.full((3, 2), np.array([-0.4, 4.4])),
1069+
),
1070+
(
1071+
# implied size = (11, 7) will be overwritten by (5, 3)
1072+
np.array([0.4, 0.6]),
1073+
[Normal.dist(-2, 5, size=(11, 7)), Normal.dist(6, 3, size=(11, 7))],
1074+
(5, 3),
1075+
np.full(shape=(5, 3), fill_value=2.8),
1076+
),
1077+
],
1078+
)
1079+
def test_list_univariate_components(self, weights, comp_dists, size, expected):
1080+
with Model() as model:
1081+
Mixture("x", weights, comp_dists, size=size)
1082+
assert_moment_is_expected(model, expected, check_finite_logp=False)
1083+
1084+
@pytest.mark.parametrize(
1085+
"weights, comp_dists, size, expected",
1086+
[
1087+
(
1088+
np.array([0.4, 0.6]),
1089+
MvNormal.dist(mu=np.array([[-1, -2], [3, 5]]), cov=np.eye(2) * 0.3),
1090+
None,
1091+
np.array([1.4, 2.2]),
1092+
),
1093+
(
1094+
np.array([0.5, 0.5]),
1095+
Dirichlet.dist(a=np.array([[0.0001, 0.0001, 1000], [2, 4, 6]])),
1096+
(4,),
1097+
np.array(np.full((4, 3), [1 / 12, 1 / 6, 3 / 4])),
1098+
),
1099+
(
1100+
np.array([0.4, 0.6]),
1101+
MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3, size=(4, 2)),
1102+
None,
1103+
np.full((4, 3), [-10, 0, 10]),
1104+
),
1105+
(
1106+
np.array([[1.0, 0], [0.0, 1.0]]),
1107+
MvNormal.dist(
1108+
mu=np.array([[-5, -10, -15], [5, 10, 15]]), cov=np.eye(3) * 3, size=(2,)
1109+
),
1110+
(3, 2),
1111+
np.full((3, 2, 3), [[-5, -10, -15], [5, 10, 15]]),
1112+
),
1113+
],
1114+
)
1115+
def test_single_multivariate_component(self, weights, comp_dists, size, expected):
1116+
with Model() as model:
1117+
Mixture("x", weights, comp_dists, size=size)
1118+
assert_moment_is_expected(model, expected, check_finite_logp=False)
1119+
1120+
@pytest.mark.parametrize(
1121+
"weights, comp_dists, size, expected",
1122+
[
1123+
(
1124+
np.array([0.4, 0.6]),
1125+
[
1126+
MvNormal.dist(mu=np.array([-1, -2]), cov=np.eye(2) * 0.3),
1127+
MvNormal.dist(mu=np.array([3, 5]), cov=np.eye(2) * 0.8),
1128+
],
1129+
None,
1130+
np.array([1.4, 2.2]),
1131+
),
1132+
(
1133+
np.array([0.4, 0.6]),
1134+
[
1135+
Dirichlet.dist(a=np.array([2, 3, 5])),
1136+
MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3),
1137+
],
1138+
(4,),
1139+
np.array(np.full((4, 3), [-5.92, 0.12, 6.2])),
1140+
),
1141+
(
1142+
np.array([0.4, 0.6]),
1143+
[
1144+
Dirichlet.dist(a=np.array([2, 3, 5]), size=(2,)),
1145+
MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3, size=(2,)),
1146+
],
1147+
None,
1148+
np.full((2, 3), [-5.92, 0.12, 6.2]),
1149+
),
1150+
(
1151+
np.array([[1.0, 0], [0.0, 1.0]]),
1152+
[
1153+
MvNormal.dist(mu=np.array([-5, -10, -15]), cov=np.eye(3) * 3, size=(2,)),
1154+
MvNormal.dist(mu=np.array([5, 10, 15]), cov=np.eye(3) * 3, size=(2,)),
1155+
],
1156+
(3, 2),
1157+
np.full((3, 2, 3), [[-5, -10, -15], [5, 10, 15]]),
1158+
),
1159+
],
1160+
)
1161+
def test_list_multivariate_components(self, weights, comp_dists, size, expected):
1162+
with Model() as model:
1163+
Mixture("x", weights, comp_dists, size=size)
1164+
assert_moment_is_expected(model, expected, check_finite_logp=False)

0 commit comments

Comments
 (0)