Skip to content

Commit fe9cb3d

Browse files
committed
Update to limit support to univariate time series
1 parent b6fa76b commit fe9cb3d

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

pymc/distributions/timeseries.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -912,8 +912,8 @@ class EulerMaruyama(Distribution):
912912
sde_pars: tuple
913913
parameters of the SDE, passed as ``*args`` to ``sde_fn``
914914
init_dist : unnamed distribution, optional
915-
Scalar or vector distribution for initial values. Unnamed refers to distributions
916-
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
915+
Scalar distribution for initial values. Unnamed refers to distributions created with
916+
the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
917917
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
918918
919919
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -953,9 +953,9 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
953953
f"got {type(init_dist)}"
954954
)
955955
check_dist_not_registered(init_dist)
956-
if init_dist.owner.op.ndim_supp > 1:
956+
if init_dist.owner.op.ndim_supp > 0:
957957
raise ValueError(
958-
"Init distribution must have a scalar or vector support dimension, ",
958+
"Init distribution must have a scalar support dimension, ",
959959
f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
960960
)
961961
else:

pymc/tests/distributions/test_timeseries.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,10 @@ def sde_fn(x, k, d, s):
847847
sde_pars = [1.0, 2.0, 0.1]
848848
sde_pars[batched_param] = sde_pars[batched_param] * param_val
849849
with Model() as t0:
850-
y = EulerMaruyama("y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs)
850+
init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
851+
y = EulerMaruyama(
852+
"y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
853+
)
851854

852855
y_eval = draw(y, draws=2)
853856
assert y_eval[0].shape == (batch_size, steps)
@@ -859,7 +862,15 @@ def sde_fn(x, k, d, s):
859862
for i in range(batch_size):
860863
sde_pars_slice = sde_pars.copy()
861864
sde_pars_slice[batched_param] = sde_pars[batched_param][i]
862-
EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars_slice, **kwargs)
865+
init_dist = pm.Normal.dist(0, 10)
866+
EulerMaruyama(
867+
f"y_{i}",
868+
dt=0.02,
869+
sde_fn=sde_fn,
870+
sde_pars=sde_pars_slice,
871+
init_dist=init_dist,
872+
**kwargs,
873+
)
863874

864875
t0_init = t0.initial_point()
865876
t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
@@ -872,7 +883,13 @@ def test_change_dist_size1(self):
872883
def sde1(x, k, d, s):
873884
return (k - d * x, s)
874885

875-
base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde1, sde_pars=(1, 2, 0.1), shape=(5, 10))
886+
base_dist = EulerMaruyama.dist(
887+
dt=0.01,
888+
sde_fn=sde1,
889+
sde_pars=(1, 2, 0.1),
890+
init_dist=pm.Normal.dist(0, 10),
891+
shape=(5, 10),
892+
)
876893

877894
new_dist = change_dist_size(base_dist, (4,))
878895
assert new_dist.eval().shape == (4, 10)
@@ -885,7 +902,9 @@ def sde2(p, s):
885902
N = 500.0
886903
return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N)
887904

888-
base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde2, sde_pars=(0.1,), shape=(3, 10))
905+
base_dist = EulerMaruyama.dist(
906+
dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10)
907+
)
889908

890909
new_dist = change_dist_size(base_dist, (4,))
891910
assert new_dist.eval().shape == (4, 10)
@@ -913,7 +932,9 @@ def _gen_sde_path(sde, pars, dt, n, x0):
913932
# build model
914933
with Model() as model:
915934
lamh = Flat("lamh")
916-
xh = EulerMaruyama("xh", dt, sde, (lamh,), steps=N, initval=x)
935+
xh = EulerMaruyama(
936+
"xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10)
937+
)
917938
Normal("zh", mu=xh, sigma=sig2, observed=z)
918939
# invert
919940
with model:

0 commit comments

Comments
 (0)