|
53 | 53 | from pymc.step_methods import Metropolis
|
54 | 54 | from pymc.tests.helpers import SeededTest
|
55 | 55 | from pymc.tests.test_distributions import Domain, Simplex
|
| 56 | +from pymc.tests.test_distributions_moments import assert_moment_is_expected |
56 | 57 | from pymc.tests.test_distributions_random import pymc_random
|
57 | 58 |
|
58 | 59 |
|
@@ -971,3 +972,193 @@ def test_broadcasting_in_shape(self):
|
971 | 972 | prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
|
972 | 973 |
|
973 | 974 | assert prior["mix"].shape == (self.n_samples, 1000)
|
| 975 | + |
| 976 | + |
| 977 | +class TestMixtureMoments: |
| 978 | + @pytest.mark.parametrize( |
| 979 | + "weights, comp_dists, size, expected", |
| 980 | + [ |
| 981 | + ( |
| 982 | + np.array([0.4, 0.6]), |
| 983 | + Normal.dist(mu=np.array([-2, 6]), sigma=np.array([5, 3])), |
| 984 | + None, |
| 985 | + 2.8, |
| 986 | + ), |
| 987 | + ( |
| 988 | + np.tile(1 / 13, 13), |
| 989 | + Normal.dist(-2, 1, size=(13,)), |
| 990 | + (3,), |
| 991 | + np.full((3,), -2), |
| 992 | + ), |
| 993 | + ( |
| 994 | + np.array([0.4, 0.6]), |
| 995 | + Normal.dist([-2, 6], 3), |
| 996 | + (5, 3), |
| 997 | + np.full((5, 3), 2.8), |
| 998 | + ), |
| 999 | + ( |
| 1000 | + np.broadcast_to(np.array([0.4, 0.6]), (5, 3, 2)), |
| 1001 | + Normal.dist(np.array([-2, 6]), np.array([5, 3])), |
| 1002 | + None, |
| 1003 | + np.full(shape=(5, 3), fill_value=2.8), |
| 1004 | + ), |
| 1005 | + ( |
| 1006 | + np.array([0.4, 0.6]), |
| 1007 | + Normal.dist(np.array([-2, 6]), np.array([5, 3]), size=(5, 3, 2)), |
| 1008 | + None, |
| 1009 | + np.full(shape=(5, 3), fill_value=2.8), |
| 1010 | + ), |
| 1011 | + ( |
| 1012 | + np.array([[0.8, 0.2], [0.2, 0.8]]), |
| 1013 | + Normal.dist(np.array([-2, 6])), |
| 1014 | + None, |
| 1015 | + np.array([-0.4, 4.4]), |
| 1016 | + ), |
| 1017 | + # implied size = (11, 7) will be overwritten by (5, 3) |
| 1018 | + ( |
| 1019 | + np.array([0.4, 0.6]), |
| 1020 | + Normal.dist(np.array([-2, 6]), np.array([5, 3]), size=(11, 7, 2)), |
| 1021 | + (5, 3), |
| 1022 | + np.full(shape=(5, 3), fill_value=2.8), |
| 1023 | + ), |
| 1024 | + ], |
| 1025 | + ) |
| 1026 | + def test_single_univariate_component(self, weights, comp_dists, size, expected): |
| 1027 | + with Model() as model: |
| 1028 | + Mixture("x", weights, comp_dists, size=size) |
| 1029 | + assert_moment_is_expected(model, expected, check_finite_logp=False) |
| 1030 | + |
| 1031 | + @pytest.mark.parametrize( |
| 1032 | + "weights, comp_dists, size, expected", |
| 1033 | + [ |
| 1034 | + ( |
| 1035 | + np.array([1, 0]), |
| 1036 | + [Normal.dist(-2, 5), Normal.dist(6, 3)], |
| 1037 | + None, |
| 1038 | + -2, |
| 1039 | + ), |
| 1040 | + ( |
| 1041 | + np.array([0.4, 0.6]), |
| 1042 | + [Normal.dist(-2, 5, size=(2,)), Normal.dist(6, 3, size=(2,))], |
| 1043 | + None, |
| 1044 | + np.full((2,), 2.8), |
| 1045 | + ), |
| 1046 | + ( |
| 1047 | + np.array([0.5, 0.5]), |
| 1048 | + [Normal.dist(-2, 5), Exponential.dist(lam=1 / 3)], |
| 1049 | + (3, 5), |
| 1050 | + np.full((3, 5), 0.5), |
| 1051 | + ), |
| 1052 | + ( |
| 1053 | + np.broadcast_to(np.array([0.4, 0.6]), (5, 3, 2)), |
| 1054 | + [Normal.dist(-2, 5), Normal.dist(6, 3)], |
| 1055 | + None, |
| 1056 | + np.full(shape=(5, 3), fill_value=2.8), |
| 1057 | + ), |
| 1058 | + ( |
| 1059 | + np.array([[0.8, 0.2], [0.2, 0.8]]), |
| 1060 | + [Normal.dist(-2, 5), Normal.dist(6, 3)], |
| 1061 | + None, |
| 1062 | + np.array([-0.4, 4.4]), |
| 1063 | + ), |
| 1064 | + ( |
| 1065 | + np.array([[0.8, 0.2], [0.2, 0.8]]), |
| 1066 | + [Normal.dist(-2, 5), Normal.dist(6, 3)], |
| 1067 | + (3, 2), |
| 1068 | + np.full((3, 2), np.array([-0.4, 4.4])), |
| 1069 | + ), |
| 1070 | + ( |
| 1071 | + # implied size = (11, 7) will be overwritten by (5, 3) |
| 1072 | + np.array([0.4, 0.6]), |
| 1073 | + [Normal.dist(-2, 5, size=(11, 7)), Normal.dist(6, 3, size=(11, 7))], |
| 1074 | + (5, 3), |
| 1075 | + np.full(shape=(5, 3), fill_value=2.8), |
| 1076 | + ), |
| 1077 | + ], |
| 1078 | + ) |
| 1079 | + def test_list_univariate_components(self, weights, comp_dists, size, expected): |
| 1080 | + with Model() as model: |
| 1081 | + Mixture("x", weights, comp_dists, size=size) |
| 1082 | + assert_moment_is_expected(model, expected, check_finite_logp=False) |
| 1083 | + |
| 1084 | + @pytest.mark.parametrize( |
| 1085 | + "weights, comp_dists, size, expected", |
| 1086 | + [ |
| 1087 | + ( |
| 1088 | + np.array([0.4, 0.6]), |
| 1089 | + MvNormal.dist(mu=np.array([[-1, -2], [3, 5]]), cov=np.eye(2) * 0.3), |
| 1090 | + None, |
| 1091 | + np.array([1.4, 2.2]), |
| 1092 | + ), |
| 1093 | + ( |
| 1094 | + np.array([0.5, 0.5]), |
| 1095 | + Dirichlet.dist(a=np.array([[0.0001, 0.0001, 1000], [2, 4, 6]])), |
| 1096 | + (4,), |
| 1097 | + np.array(np.full((4, 3), [1 / 12, 1 / 6, 3 / 4])), |
| 1098 | + ), |
| 1099 | + ( |
| 1100 | + np.array([0.4, 0.6]), |
| 1101 | + MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3, size=(4, 2)), |
| 1102 | + None, |
| 1103 | + np.full((4, 3), [-10, 0, 10]), |
| 1104 | + ), |
| 1105 | + ( |
| 1106 | + np.array([[1.0, 0], [0.0, 1.0]]), |
| 1107 | + MvNormal.dist( |
| 1108 | + mu=np.array([[-5, -10, -15], [5, 10, 15]]), cov=np.eye(3) * 3, size=(2,) |
| 1109 | + ), |
| 1110 | + (3, 2), |
| 1111 | + np.full((3, 2, 3), [[-5, -10, -15], [5, 10, 15]]), |
| 1112 | + ), |
| 1113 | + ], |
| 1114 | + ) |
| 1115 | + def test_single_multivariate_component(self, weights, comp_dists, size, expected): |
| 1116 | + with Model() as model: |
| 1117 | + Mixture("x", weights, comp_dists, size=size) |
| 1118 | + assert_moment_is_expected(model, expected, check_finite_logp=False) |
| 1119 | + |
| 1120 | + @pytest.mark.parametrize( |
| 1121 | + "weights, comp_dists, size, expected", |
| 1122 | + [ |
| 1123 | + ( |
| 1124 | + np.array([0.4, 0.6]), |
| 1125 | + [ |
| 1126 | + MvNormal.dist(mu=np.array([-1, -2]), cov=np.eye(2) * 0.3), |
| 1127 | + MvNormal.dist(mu=np.array([3, 5]), cov=np.eye(2) * 0.8), |
| 1128 | + ], |
| 1129 | + None, |
| 1130 | + np.array([1.4, 2.2]), |
| 1131 | + ), |
| 1132 | + ( |
| 1133 | + np.array([0.4, 0.6]), |
| 1134 | + [ |
| 1135 | + Dirichlet.dist(a=np.array([2, 3, 5])), |
| 1136 | + MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3), |
| 1137 | + ], |
| 1138 | + (4,), |
| 1139 | + np.array(np.full((4, 3), [-5.92, 0.12, 6.2])), |
| 1140 | + ), |
| 1141 | + ( |
| 1142 | + np.array([0.4, 0.6]), |
| 1143 | + [ |
| 1144 | + Dirichlet.dist(a=np.array([2, 3, 5]), size=(2,)), |
| 1145 | + MvNormal.dist(mu=np.array([-10, 0, 10]), cov=np.eye(3) * 3, size=(2,)), |
| 1146 | + ], |
| 1147 | + None, |
| 1148 | + np.full((2, 3), [-5.92, 0.12, 6.2]), |
| 1149 | + ), |
| 1150 | + ( |
| 1151 | + np.array([[1.0, 0], [0.0, 1.0]]), |
| 1152 | + [ |
| 1153 | + MvNormal.dist(mu=np.array([-5, -10, -15]), cov=np.eye(3) * 3, size=(2,)), |
| 1154 | + MvNormal.dist(mu=np.array([5, 10, 15]), cov=np.eye(3) * 3, size=(2,)), |
| 1155 | + ], |
| 1156 | + (3, 2), |
| 1157 | + np.full((3, 2, 3), [[-5, -10, -15], [5, 10, 15]]), |
| 1158 | + ), |
| 1159 | + ], |
| 1160 | + ) |
| 1161 | + def test_list_multivariate_components(self, weights, comp_dists, size, expected): |
| 1162 | + with Model() as model: |
| 1163 | + Mixture("x", weights, comp_dists, size=size) |
| 1164 | + assert_moment_is_expected(model, expected, check_finite_logp=False) |
0 commit comments