File tree 1 file changed +5
-12
lines changed
1 file changed +5
-12
lines changed Original file line number Diff line number Diff line change @@ -140,20 +140,13 @@ def test_marginal_likelihood(self):
140
140
with pm .Model () as model :
141
141
a = pm .Beta ("a" , alpha , beta )
142
142
y = pm .Bernoulli ("y" , a , observed = data )
143
- trace = pm .sample_smc (2000 , return_inferencedata = False )
144
- marginals .append (trace .report .log_marginal_likelihood )
143
+ trace = pm .sample_smc (2000 , chains = 2 , return_inferencedata = False )
144
+ # log_marignal_likelihood is found in the last value of each chain
145
+ lml = np .mean ([chain [- 1 ] for chain in trace .report .log_marginal_likelihood ])
146
+ marginals .append (lml )
145
147
146
148
# compare to the analytical result
147
- assert (
148
- np .abs (
149
- np .exp (
150
- np .nanmean (np .array (marginals [1 ], dtype = float ))
151
- - np .nanmean (np .array (marginals [0 ], dtype = float ))
152
- - 4.0
153
- )
154
- )
155
- <= 1
156
- )
149
+ assert abs (np .exp (marginals [1 ] - marginals [0 ]) - 4.0 ) <= 1
157
150
158
151
def test_start (self ):
159
152
with pm .Model () as model :
You can’t perform that action at this time.
0 commit comments