Skip to content

Commit 333f7f3

Browse files
committed
Fix TestSMC.test_marginal_likelihood
When chains have different lengths, this test would fail because the resulting `log_marginal_likelihood` would be non-square
1 parent 1ab0757 commit 333f7f3

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

pymc/tests/test_smc.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,13 @@ def test_marginal_likelihood(self):
140140
with pm.Model() as model:
141141
a = pm.Beta("a", alpha, beta)
142142
y = pm.Bernoulli("y", a, observed=data)
143-
trace = pm.sample_smc(2000, return_inferencedata=False)
144-
marginals.append(trace.report.log_marginal_likelihood)
143+
trace = pm.sample_smc(2000, chains=2, return_inferencedata=False)
144+
# log_marignal_likelihood is found in the last value of each chain
145+
lml = np.mean([chain[-1] for chain in trace.report.log_marginal_likelihood])
146+
marginals.append(lml)
145147

146148
# compare to the analytical result
147-
assert (
148-
np.abs(
149-
np.exp(
150-
np.nanmean(np.array(marginals[1], dtype=float))
151-
- np.nanmean(np.array(marginals[0], dtype=float))
152-
- 4.0
153-
)
154-
)
155-
<= 1
156-
)
149+
assert abs(np.exp(marginals[1] - marginals[0]) - 4.0) <= 1
157150

158151
def test_start(self):
159152
with pm.Model() as model:

0 commit comments

Comments
 (0)