Skip to content

Commit 0425312

Browse files
committed
Create dispatched ndim_supp_dist
1 parent 4c720cb commit 0425312

File tree

7 files changed

+72
-28
lines changed

7 files changed

+72
-28
lines changed

pymc/distributions/censored.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from aesara.tensor.random.op import RandomVariable
2020

2121
from pymc.distributions.distribution import SymbolicDistribution, _moment
22+
from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist
2223
from pymc.util import check_dist_not_registered
2324

2425

@@ -65,7 +66,7 @@ def dist(cls, dist, lower, upper, **kwargs):
6566
raise ValueError(
6667
f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
6768
)
68-
if dist.owner.op.ndim_supp > 0:
69+
if ndim_supp_dist(dist) > 0:
6970
raise NotImplementedError(
7071
"Censoring of multivariate distributions has not been implemented yet"
7172
)
@@ -95,10 +96,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
9596

9697
return rv_out
9798

98-
@classmethod
99-
def ndim_supp(cls, *dist_params):
100-
return 0
101-
10299
@classmethod
103100
def change_size(cls, rv, new_size, expand=False):
104101
dist_node = rv.tag.dist.owner
@@ -124,6 +121,12 @@ def graph_rvs(cls, rv):
124121
return (rv.tag.dist,)
125122

126123

124+
@_ndim_supp_dist.register(Clip)
125+
def ndim_supp_censored(op, dist):
126+
# We only support Censoring of univariate distributions
127+
return 0
128+
129+
127130
@_moment.register(Clip)
128131
def moment_censored(op, rv, dist, lower, upper):
129132
moment = at.switch(

pymc/distributions/distribution.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
convert_shape,
4545
convert_size,
4646
find_size,
47+
ndim_supp_dist,
4748
resize_from_dims,
4849
resize_from_observed,
4950
)
@@ -399,16 +400,20 @@ def __new__(
399400
cls.rv_op
400401
Returns a TensorVariable that represents the symbolic distribution
401402
parametrized by a default set of parameters and a size and rngs arguments
402-
cls.ndim_supp
403-
Returns the support of the symbolic distribution, given the default
404-
parameters. This may not always be constant, for instance if the symbolic
405-
distribution can be defined based on an arbitrary base distribution.
406403
cls.change_size
407404
Returns an equivalent symbolic distribution with a different size. This is
408405
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
409406
cls.graph_rvs
410407
Returns base RVs in a symbolic distribution.
411408
409+
Furthermore, symbolic distributions must have a dispatch version of the following
410+
functions for correct behavior in PyMC:
411+
_ndim_supp_dist
412+
Returns the support dimensionality of the symbolic distribution. This may not always be
413+
constant, for instance if the symbolic distribution can be defined based
414+
on an arbitrary base distribution. This is called by
415+
`pymc.distributions.shape_utils.ndim_supp_dist`
416+
412417
Parameters
413418
----------
414419
cls : type
@@ -559,8 +564,11 @@ def dist(
559564
shape = convert_shape(shape)
560565
size = convert_size(size)
561566

567+
# Create a temporary dist to obtain the ndim_supp
568+
ndim_supp = ndim_supp_dist(cls.rv_op(*dist_params, size=size))
569+
562570
create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
563-
shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
571+
shape=shape, size=size, ndim_supp=ndim_supp
564572
)
565573
# Create the RV with a `size` right away.
566574
# This is not necessarily the final result.

pymc/distributions/mixture.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pymc.distributions.dist_math import check_parameters
3232
from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
3333
from pymc.distributions.logprob import logcdf, logp
34-
from pymc.distributions.shape_utils import to_tuple
34+
from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist, to_tuple
3535
from pymc.distributions.transforms import _default_transform
3636
from pymc.util import check_dist_not_registered
3737
from pymc.vartypes import continuous_types, discrete_types
@@ -188,7 +188,7 @@ def dist(cls, w, comp_dists, **kwargs):
188188
f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
189189
)
190190
check_dist_not_registered(dist)
191-
components_ndim_supp.add(dist.owner.op.ndim_supp)
191+
components_ndim_supp.add(ndim_supp_dist(dist))
192192

193193
if len(components_ndim_supp) > 1:
194194
raise ValueError(
@@ -209,7 +209,7 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
209209
mix_indexes_rng = aesara.shared(np.random.default_rng())
210210

211211
single_component = len(components) == 1
212-
ndim_supp = components[0].owner.op.ndim_supp
212+
ndim_supp = ndim_supp_dist(components[0])
213213

214214
if size is not None:
215215
components = cls._resize_components(size, *components)
@@ -319,17 +319,12 @@ def _resize_components(cls, size, *components):
319319
if len(components) == 1:
320320
# If we have a single component, we need to keep the length of the mixture
321321
# axis intact, because that's what determines the number of mixture components
322-
mix_axis = -components[0].owner.op.ndim_supp - 1
322+
mix_axis = -ndim_supp_dist(components[0]) - 1
323323
mix_size = components[0].shape[mix_axis]
324324
size = tuple(size) + (mix_size,)
325325

326326
return [change_rv_size(component, size) for component in components]
327327

328-
@classmethod
329-
def ndim_supp(cls, weights, *components):
330-
# We already checked that all components have the same support dimensionality
331-
return components[0].owner.op.ndim_supp
332-
333328
@classmethod
334329
def change_size(cls, rv, new_size, expand=False):
335330
mix_indexes_rng, weights, *components = rv.owner.inputs
@@ -338,7 +333,7 @@ def change_size(cls, rv, new_size, expand=False):
338333
if expand:
339334
component = components[0]
340335
# Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
341-
size_dims = component.ndim - component.owner.op.ndim_supp
336+
size_dims = component.ndim - ndim_supp_dist(component)
342337
if len(components) == 1:
343338
# If we have a single component, new size should ignore the mixture axis
344339
# dimension, as that is not touched by `_resize_components`
@@ -359,6 +354,13 @@ def graph_rvs(cls, rv):
359354
return (*rv.owner.inputs[2:], rv)
360355

361356

357+
@_ndim_supp_dist.register(MarginalMixtureRV)
358+
def ndim_supp_marginal_mixture(op, rv):
359+
# We already checked that all components have the same support dimensionality
360+
components = rv.owner.inputs[2:]
361+
return ndim_supp_dist(components[0])
362+
363+
362364
@_get_measurable_outputs.register(MarginalMixtureRV)
363365
def _get_measurable_outputs_MarginalMixtureRV(op, node):
364366
# This tells Aeppl that the second output is the measurable one
@@ -372,7 +374,7 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
372374
# single component
373375
if len(components) == 1:
374376
# Need to broadcast value across mixture axis
375-
mix_axis = -components[0].owner.op.ndim_supp - 1
377+
mix_axis = -ndim_supp_dist(components[0]) - 1
376378
components_logp = logp(components[0], at.expand_dims(value, mix_axis))
377379
else:
378380
components_logp = at.stack(
@@ -405,7 +407,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
405407
# single component
406408
if len(components) == 1:
407409
# Need to broadcast value across mixture axis
408-
mix_axis = -components[0].owner.op.ndim_supp - 1
410+
mix_axis = -ndim_supp_dist(components[0]) - 1
409411
components_logcdf = logcdf(components[0], at.expand_dims(value, mix_axis))
410412
else:
411413
components_logcdf = at.stack(
@@ -434,7 +436,7 @@ def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):
434436

435437
@_moment.register(MarginalMixtureRV)
436438
def marginal_mixture_moment(op, rv, rng, weights, *components):
437-
ndim_supp = components[0].owner.op.ndim_supp
439+
ndim_supp = ndim_supp_dist(components[0])
438440
weights = at.shape_padright(weights, ndim_supp)
439441
mix_axis = -ndim_supp - 1
440442

pymc/distributions/multivariate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from pymc.distributions.distribution import Continuous, Discrete, moment
6060
from pymc.distributions.shape_utils import (
6161
broadcast_dist_samples_to,
62+
ndim_supp_dist,
6263
rv_size_is_none,
6364
to_tuple,
6465
)
@@ -1187,7 +1188,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
11871188
isinstance(sd_dist, Variable)
11881189
and sd_dist.owner is not None
11891190
and isinstance(sd_dist.owner.op, RandomVariable)
1190-
and sd_dist.owner.op.ndim_supp < 2
1191+
and ndim_supp_dist(sd_dist) < 2
11911192
):
11921193
raise TypeError("sd_dist must be a scalar or vector distribution variable")
11931194

@@ -1197,7 +1198,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
11971198
# diagonal element.
11981199
# Since `eta` and `n` are forced to be scalars we don't need to worry about
11991200
# implied batched dimensions for the time being.
1200-
if sd_dist.owner.op.ndim_supp == 0:
1201+
if ndim_supp_dist(sd_dist) == 0:
12011202
sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
12021203
else:
12031204
# The support shape must be `n` but we have no way of controlling it

pymc/distributions/shape_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
A collection of common shape operations needed for broadcasting
1818
samples from probability distributions for stochastic nodes in PyMC.
1919
"""
20-
20+
from functools import singledispatch
2121
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast
2222

2323
import numpy as np
2424

2525
from aesara.graph.basic import Variable
26+
from aesara.graph.op import Op
27+
from aesara.tensor.elemwise import Elemwise
28+
from aesara.tensor.random.op import RandomVariable
2629
from aesara.tensor.var import TensorVariable
2730
from typing_extensions import TypeAlias
2831

@@ -619,3 +622,23 @@ def find_size(
619622
def rv_size_is_none(size: Variable) -> bool:
620623
"""Check whether an rv size is None (i.e., at.Constant([]))"""
621624
return size.type.shape == (0,) # type: ignore [attr-defined]
625+
626+
627+
@singledispatch
628+
def _ndim_supp_dist(op: Op, dist: TensorVariable) -> int:
629+
raise TypeError(f"ndim_supp not known for Op {op}")
630+
631+
632+
def ndim_supp_dist(dist: TensorVariable) -> int:
633+
return _ndim_supp_dist(dist.owner.op, dist)
634+
635+
636+
@_ndim_supp_dist.register(RandomVariable)
637+
def ndim_supp_rv(op: Op, rv: TensorVariable):
638+
return op.ndim_supp
639+
640+
641+
@_ndim_supp_dist.register(Elemwise)
642+
def ndim_supp_elemwise(op: Op, *args, **kwargs):
643+
"""For Elemwise Ops, dispatch on respective scalar_op"""
644+
return _ndim_supp_dist(op.scalar_op, *args, **kwargs)

pymc/distributions/timeseries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pymc.distributions import distribution, logprob, multivariate
2525
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
2626
from pymc.distributions.dist_math import check_parameters
27-
from pymc.distributions.shape_utils import to_tuple
27+
from pymc.distributions.shape_utils import ndim_supp_dist, to_tuple
2828
from pymc.util import check_dist_not_registered
2929

3030
__all__ = [
@@ -175,7 +175,7 @@ def dist(
175175
isinstance(init, at.TensorVariable)
176176
and init.owner is not None
177177
and isinstance(init.owner.op, RandomVariable)
178-
and init.owner.op.ndim_supp == 0
178+
and ndim_supp_dist(init) == 0
179179
):
180180
raise TypeError("init must be a univariate distribution variable")
181181

pymc/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from aesara.compile.sharedvalue import SharedVariable
4444
from aesara.graph.basic import Constant, Variable, graph_inputs
4545
from aesara.graph.fg import FunctionGraph
46+
from aesara.tensor.random.op import RandomVariable
4647
from aesara.tensor.random.opt import local_subtensor_rv_lift
4748
from aesara.tensor.random.var import RandomStateSharedVariable
4849
from aesara.tensor.sharedvar import ScalarSharedVariable
@@ -1330,6 +1331,12 @@ def make_obs_var(
13301331
)
13311332
warnings.warn(impute_message, ImputationWarning)
13321333

1334+
# TODO: Add test for this
1335+
if not isinstance(rv_var.owner.op, RandomVariable):
1336+
raise NotImplementedError(
1337+
f"Automatic inputation is only supported for RandomVariables, but {rv_var} is of type {rv_var.owner.op}"
1338+
)
1339+
13331340
if rv_var.owner.op.ndim_supp > 0:
13341341
raise NotImplementedError(
13351342
f"Automatic inputation is only supported for univariate RandomVariables, but {rv_var} is multivariate"

0 commit comments

Comments
 (0)