Skip to content

Commit 4aeedca

Browse files
committed
Update logp calc
1 parent b53b7e7 commit 4aeedca

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

pymc/distributions/timeseries.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from pymc.distributions import distribution, multivariate
2424
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
25+
from pymc.distributions.dist_math import check_parameters
2526
from pymc.distributions.shape_utils import to_tuple
2627

2728
__all__ = [
@@ -297,11 +298,6 @@ def logp(
297298
-------
298299
TensorVariable
299300
"""
300-
import pymc as pm
301-
302-
# Implement using AePPL
303-
# I need to create a graph that calculates the logp of GRW
304-
# I can use AePPL or PyMC to do it
305301

306302
def normal_logp(value, mu, sigma):
307303
logp = (
@@ -315,21 +311,13 @@ def normal_logp(value, mu, sigma):
315311
init_logp = normal_logp(value[0] - init, mu, sigma)
316312

317313
# Create logp calculation graph for innovations
318-
stationary_vals = at.diff(value[1:]) - init
314+
stationary_vals = at.diff(value[1:])
319315
innov_logp = normal_logp(stationary_vals, mu, sigma)
320316

321-
""" A bunch of stuff that can be ignored
322-
innit_logp = pm.logp(pm.Normal.dist(mu, sigma), value[:1] - init)
323-
# https: // aesara.readthedocs.io / en / latest / library / tensor / extra_ops.html?highlight = at.diff
324-
innov_logp = pm.logp(pm.Normal.dist(mu, sigma), at.diff(value))
325-
# https: // numpy.org/doc/stable/ reference / generated / numpy.concatenate.html
326-
"""
327-
328-
# Return both calculation logps in a vector. This is fine because somewhere
329-
# down the line these will be summed together
317+
# Return both calculation logps in a vector
330318
total_logp = at.concatenate([init_logp, innov_logp])
331-
# total_logp = at.sum([init_logp, innov_logp], keepdims=False)
332319

320+
total_logp = check_parameters(total_logp, sigma > 0, msg="sigma > 0")
333321
return total_logp
334322

335323

0 commit comments

Comments
 (0)