Commit a288c72

Manage SharedVariables explicitly in SymbolicRandomVariable
1 parent 7c9aaac commit a288c72

File tree

5 files changed: +74 -41 lines changed

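In short: SymbolicRandomVariable is an OpFromGraph, and previously any RNG shared variables used by its inner graph were captured implicitly as hidden inputs. After this commit every RNG is threaded through as an explicit input and its next state as an explicit output (the new `strict=True` default presumably enforces that nothing is captured implicitly). A minimal sketch of the resulting pattern, mirroring the test added at the bottom of this commit; variable names here are illustrative only:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.distributions.distribution import SymbolicRandomVariable

rng = pytensor.shared(np.random.default_rng())

# Build the inner graph on a dummy RNG of the same type as the shared one
dummy_rng = rng.type()
dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs

# The RNG is an explicit input and its next state an explicit output
op = SymbolicRandomVariable([dummy_rng], [dummy_next_rng, dummy_x], ndim_supp=0)

next_rng, x = op(rng)  # the shared RNG is passed when the op is applied
assert op.update(x.owner) == {rng: next_rng}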

pymc/distributions/distribution.py

Lines changed: 23 additions & 12 deletions
@@ -299,6 +299,7 @@ def __init__(
             raise ValueError("ndim_supp or gufunc_signature must be provided")

         kwargs.setdefault("inline", True)
+        kwargs.setdefault("strict", True)
         super().__init__(*args, **kwargs)

     def update(self, node: Node) -> dict[Variable, Variable]:
@@ -702,7 +703,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
     symbolic random methods.
     """

-    default_output = -1
+    default_output = 0

     _print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")

@@ -805,14 +806,17 @@ def rv_op(
         if logp is not None:

             @_logprob.register(rv_type)
-            def custom_dist_logp(op, values, size, *params, **kwargs):
-                return logp(values[0], *params[: len(dist_params)])
+            def custom_dist_logp(op, values, size, *inputs, **kwargs):
+                [value] = values
+                rv_params = inputs[: len(dist_params)]
+                return logp(value, *rv_params)

         if logcdf is not None:

             @_logcdf.register(rv_type)
-            def custom_dist_logcdf(op, value, size, *params, **kwargs):
-                return logcdf(value, *params[: len(dist_params)])
+            def custom_dist_logcdf(op, value, size, *inputs, **kwargs):
+                rv_params = inputs[: len(dist_params)]
+                return logcdf(value, *rv_params)

         if support_point is not None:

@@ -845,22 +849,29 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
             dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
             dummy_rv = dist(*dummy_dist_params, dummy_size_param)
             dummy_params = [dummy_size_param, *dummy_dist_params]
-            dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+            updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+            rngs = updates_dict.keys()
+            rngs_updates = updates_dict.values()
             new_rv_op = rv_type(
-                inputs=dummy_params,
-                outputs=[*dummy_updates_dict.values(), dummy_rv],
+                inputs=[*dummy_params, *rngs],
+                outputs=[dummy_rv, *rngs_updates],
                 signature=signature,
             )
-            new_rv = new_rv_op(new_size, *dist_params)
+            new_rv = new_rv_op(new_size, *dist_params, *rngs)

             return new_rv

+        # RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
+        # We retrieve them here
+        updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
+        rngs = updates_dict.keys()
+        rngs_updates = updates_dict.values()
         rv_op = rv_type(
-            inputs=dummy_params,
-            outputs=[*dummy_updates_dict.values(), dummy_rv],
+            inputs=[*dummy_params, *rngs],
+            outputs=[dummy_rv, *rngs_updates],
             signature=signature,
         )
-        return rv_op(size, *dist_params)
+        return rv_op(size, *dist_params, *rngs)

     @staticmethod
     def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
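For reference, `collect_default_updates` is what discovers the RNGs in the hunks above (as the new comment says, we usually don't know in advance how many are needed): it scans the graph for RNG shared variables feeding random draws and returns a `{rng: next_rng}` mapping. A small, self-contained usage sketch; the variable names are ours:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

rng = pytensor.shared(np.random.default_rng())
x = pt.random.normal(rng=rng)

# Maps each RNG shared variable to the state it should be updated to
updates = collect_default_updates(outputs=[x])
assert set(updates) == {rng}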

pymc/distributions/timeseries.py

Lines changed: 8 additions & 12 deletions
@@ -436,7 +436,6 @@ def __init__(self, *args, ar_order, constant_term, **kwargs):

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -658,13 +657,13 @@ def step(*args):
         ar_ = pt.concatenate([init_, innov_.T], axis=-1)

         ar_op = AutoRegressiveRV(
-            inputs=[rhos_, sigma_, init_, steps_],
+            inputs=[rhos_, sigma_, init_, steps_, noise_rng],
             outputs=[noise_next_rng, ar_],
             ar_order=ar_order,
             constant_term=constant_term,
         )

-        ar = ar_op(rhos, sigma, init_dist, steps)
+        ar = ar_op(rhos, sigma, init_dist, steps, noise_rng)
         return ar


@@ -731,7 +730,6 @@ class GARCH11RV(SymbolicRandomVariable):

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -797,7 +795,6 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
         # In this case the size of the init_dist depends on the parameters shape
         batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
         init_dist = change_dist_size(init_dist, batch_size)
-        # initial_vol = initial_vol * pt.ones(batch_size)

         # Create OpFromGraph representing random draws from GARCH11 process
         # Variables with underscore suffix are dummy inputs into the OpFromGraph
@@ -819,7 +816,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):

         (y_t, _), innov_updates_ = pytensor.scan(
             fn=step,
-            outputs_info=[init_, initial_vol_ * pt.ones(batch_size)],
+            outputs_info=[init_, pt.broadcast_to(initial_vol_.astype("floatX"), init_.shape)],
             non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
             n_steps=steps_,
             strict=True,
@@ -831,11 +828,11 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
         )

         garch11_op = GARCH11RV(
-            inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
+            inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_, noise_rng],
             outputs=[noise_next_rng, garch11_],
         )

-        garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
+        garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
         return garch11


@@ -891,14 +888,13 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
     ndim_supp = 1
     _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

-    def __init__(self, *args, dt, sde_fn, **kwargs):
+    def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs):
         self.dt = dt
         self.sde_fn = sde_fn
         super().__init__(*args, **kwargs)

     def update(self, node: Node):
         """Return the update mapping for the noise RV."""
-        # Since noise is a shared variable it shows up as the last node input
         return {node.inputs[-1]: node.outputs[0]}


@@ -1010,14 +1006,14 @@ def step(*prev_args):
         )

         eulermaruyama_op = EulerMaruyamaRV(
-            inputs=[init_, steps_, *sde_pars_],
+            inputs=[init_, steps_, *sde_pars_, noise_rng],
             outputs=[noise_next_rng, sde_out_],
             dt=dt,
             sde_fn=sde_fn,
             signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
         )

-        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
+        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars, noise_rng)
         return eulermaruyama

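All three time-series ops above now follow the same convention: the RNG is appended as the last explicit input and its next state is the first output, which is why each `update` method is the same one-liner and the old "noise is a shared variable" comments could be dropped. A hypothetical subclass written to the same convention:

from pytensor.graph.basic import Node
from pymc.distributions.distribution import SymbolicRandomVariable

class MyNoisyRV(SymbolicRandomVariable):
    """Hypothetical op following the last-input/first-output RNG convention."""

    def update(self, node: Node):
        # The RNG is an explicit input, appended last; its next state is output 0
        return {node.inputs[-1]: node.outputs[0]}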

pymc/distributions/truncated.py

Lines changed: 10 additions & 10 deletions
@@ -63,8 +63,7 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
         super().__init__(*args, **kwargs)

     def update(self, node: Node):
-        """Return the update mapping for the noise RV."""
-        # Since RNG is a shared variable it shows up as the last node input
+        """Return the update mapping for the internal RNG."""
         return {node.inputs[-1]: node.outputs[0]}


@@ -195,20 +194,20 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
             cdf_upper_ = pt.exp(logcdf(rv_, upper_))
             # It's okay to reuse the same rng here, because the rng in rv_ will not be
             # used by either the logcdf of icdf functions
-            uniform_ = pt.random.uniform(
+            uniform_next_rng_, uniform_ = pt.random.uniform(
                 cdf_lower_,
                 cdf_upper_,
                 rng=rng,
                 size=rv_inputs_[0],
-            )
+            ).owner.outputs
             truncated_rv_ = icdf(rv_, uniform_)
             return TruncatedRV(
                 base_rv_op=dist.owner.op,
-                inputs=graph_inputs_,
-                outputs=[uniform_.owner.outputs[0], truncated_rv_],
+                inputs=[*graph_inputs_, rng],
+                outputs=[uniform_next_rng_, truncated_rv_],
                 ndim_supp=0,
                 max_n_steps=max_n_steps,
-            )(*graph_inputs)
+            )(*graph_inputs, rng)
         except NotImplementedError:
             pass

@@ -248,13 +247,14 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
             truncated_rv_, convergence_
         )

+        [next_rng] = updates.values()
         return TruncatedRV(
             base_rv_op=dist.owner.op,
-            inputs=graph_inputs_,
-            outputs=[next(iter(updates.values())), truncated_rv_],
+            inputs=[*graph_inputs_, rng],
+            outputs=[next_rng, truncated_rv_],
             ndim_supp=0,
             max_n_steps=max_n_steps,
-        )(*graph_inputs)
+        )(*graph_inputs, rng)


 @_change_dist_size.register(TruncatedRV)
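The `.owner.outputs` idiom used in the first hunk is worth spelling out: a RandomVariable node has the next RNG state as its first output and the draw as its second, so unpacking `.owner.outputs` recovers both instead of only the draw. In isolation:

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng())
# First output: the RNG's next state; second output: the actual draw
next_rng, draw = pt.random.uniform(0.0, 1.0, rng=rng).owner.outputs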

pymc/pytensorf.py

Lines changed: 1 addition & 1 deletion
@@ -801,7 +801,7 @@ def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]


 def collect_default_updates(
-    outputs: Sequence[Variable],
+    outputs: Variable | Sequence[Variable],
     *,
     inputs: Sequence[Variable] | None = None,
     must_be_shared: bool = True,
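The widened annotation suggests callers may now pass a single variable without wrapping it in a sequence. Assuming the function body normalizes a lone Variable to a one-element list (the diff only shows the annotation), both of these should be equivalent, reusing `x` from the earlier sketch:

updates = collect_default_updates(x)    # single Variable
updates = collect_default_updates([x])  # sequence, as before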

tests/distributions/test_distribution.py

Lines changed: 32 additions & 6 deletions
@@ -41,13 +41,14 @@
     CustomDist,
     CustomDistRV,
     CustomSymbolicDistRV,
+    DiracDelta,
     PartialObservedRV,
     SymbolicRandomVariable,
     _support_point,
     create_partial_observed_rv,
     support_point,
 )
-from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
+from pymc.distributions.shape_utils import change_dist_size, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.basic import conditional_logp, logcdf, logp
@@ -584,9 +585,7 @@ def custom_dist(p, sigma, size):

     def test_custom_methods(self):
         def custom_dist(mu, size):
-            if rv_size_is_none(size):
-                return mu
-            return pt.full(size, mu)
+            return DiracDelta.dist(mu, size=size)

         def custom_support_point(rv, size, mu):
             return pt.full_like(rv, mu + 1)
@@ -778,7 +777,8 @@ def test_inline(self):
         class TestSymbolicRV(SymbolicRandomVariable):
             pass

-        x = TestSymbolicRV([], [Flat.dist()], ndim_supp=0)()
+        rng = pytensor.shared(np.random.default_rng())
+        x = TestSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)

         # By default, the SymbolicRandomVariable will not be inlined. Because we did not
         # dispatch a custom logprob function it will raise next
@@ -788,7 +788,7 @@ class TestSymbolicRV(SymbolicRandomVariable):
         class TestInlinedSymbolicRV(SymbolicRandomVariable):
             inline_logprob = True

-        x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
+        x_inline = TestInlinedSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)
         assert np.isclose(logp(x_inline, 0).eval(), 0)

     def test_default_update(self):
@@ -826,6 +826,32 @@ def update(self, node):
         ):
             compile_pymc(inputs=[], outputs=x, random_seed=431)

+    def test_recreate_with_different_rng_inputs(self):
+        """Test that we can recreate a SymbolicRandomVariable with new RNG inputs.
+
+        Related to https://github.com/pymc-devs/pytensor/issues/473
+        """
+        rng = pytensor.shared(np.random.default_rng())
+
+        dummy_rng = rng.type()
+        dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs
+
+        op = SymbolicRandomVariable(
+            [dummy_rng],
+            [dummy_next_rng, dummy_x],
+            ndim_supp=0,
+        )
+
+        next_rng, x = op(rng)
+        assert op.update(x.owner) == {rng: next_rng}
+
+        new_rng = pytensor.shared(np.random.default_rng())
+        inputs = x.owner.inputs.copy()
+        inputs[0] = new_rng
+        # This would fail with the default OpFromGraph.__call__()
+        new_next_rng, new_x = x.owner.op(*inputs)
+        assert op.update(new_x.owner) == {new_rng: new_next_rng}
+

 def test_tag_future_warning_dist():
     # Test no unexpected warnings
