Commit 4702bd4

Create dispatched resize_dist

1 parent 0425312 commit 4702bd4

13 files changed: +246 −201 lines
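In short: this commit removes the monolithic pymc.aesaraf.change_rv_size helper and the per-class change_size classmethods, routing all resizing through a single dispatched entry point, resize_dist, imported from pymc.distributions.shape_utils. The shape_utils.py hunk itself is among the 13 changed files but is not shown in this excerpt, so the following is only a plausible sketch of the dispatch pair, inferred from the @_resize_dist.register(...) call sites below; the use of functools.singledispatch is an assumption.

# Hypothetical sketch of the additions to pymc/distributions/shape_utils.py;
# that file's hunk is not shown in this excerpt.
from functools import singledispatch

from aesara.graph.op import Op
from aesara.tensor.var import TensorVariable


@singledispatch
def _resize_dist(op: Op, rv: TensorVariable, new_size, expand: bool = False):
    # Fallback when no implementation is registered for this Op type.
    raise NotImplementedError(f"Resizing of {op} is not implemented")


def resize_dist(dist: TensorVariable, new_size, expand: bool = False) -> TensorVariable:
    # Public entry point: dispatch on the Op that created `dist`
    # (e.g. Clip for Censored, MarginalMixtureRV for Mixture).
    return _resize_dist(dist.owner.op, dist, new_size, expand=expand)

With this structure, each symbolic distribution registers its resize logic next to its own definition instead of overriding a classmethod, which is exactly what the censored.py and mixture.py hunks below do.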

pymc/aesaraf.py

Lines changed: 2 additions & 62 deletions

@@ -31,7 +31,7 @@
 import scipy.sparse as sps
 
 from aeppl.logprob import CheckParameterValue
-from aesara import config, scalar
+from aesara import scalar
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import local_optimizer
@@ -45,17 +45,15 @@
     walk,
 )
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op, compute_test_value
+from aesara.graph.op import Op
 from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
 from aesara.scalar.basic import Cast
 from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.random.op import RandomVariable
-from aesara.tensor.shape import SpecifyShape
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
-from pymc.exceptions import ShapeError
 from pymc.vartypes import continuous_types, int_types, isgenerator, typefilter
 
 PotentialShapeType = Union[
@@ -142,64 +140,6 @@ def pandas_to_array(data):
     return floatX(ret)
 
 
-def change_rv_size(
-    rv: TensorVariable,
-    new_size: PotentialShapeType,
-    expand: Optional[bool] = False,
-) -> TensorVariable:
-    """Change or expand the size of a `RandomVariable`.
-
-    Parameters
-    ==========
-    rv
-        The old `RandomVariable` output.
-    new_size
-        The new size.
-    expand:
-        Expand the existing size by `new_size`.
-
-    """
-    # Check the dimensionality of the `new_size` kwarg
-    new_size_ndim = np.ndim(new_size)
-    if new_size_ndim > 1:
-        raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
-    elif new_size_ndim == 0:
-        new_size = (new_size,)
-
-    # Extract the RV node that is to be resized, together with its inputs, name and tag
-    if isinstance(rv.owner.op, SpecifyShape):
-        rv = rv.owner.inputs[0]
-    rv_node = rv.owner
-    rng, size, dtype, *dist_params = rv_node.inputs
-    name = rv.name
-    tag = rv.tag
-
-    if expand:
-        shape = tuple(rv_node.op._infer_shape(size, dist_params))
-        size = shape[: len(shape) - rv_node.op.ndim_supp]
-        new_size = tuple(new_size) + tuple(size)
-
-    # Make sure the new size is a tensor. This dtype-aware conversion helps
-    # to not unnecessarily pick up a `Cast` in some cases (see #4652).
-    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
-
-    new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
-    new_rv = new_rv_node.outputs[-1]
-    new_rv.name = name
-    for k, v in tag.__dict__.items():
-        new_rv.tag.__dict__.setdefault(k, v)
-
-    # Update "traditional" rng default_update, if that was set for old RV
-    default_update = getattr(rng, "default_update", None)
-    if default_update is not None and default_update is rv_node.outputs[0]:
-        rng.default_update = new_rv_node.outputs[0]
-
-    if config.compute_test_value != "off":
-        compute_test_value(new_rv_node)
-
-    return new_rv
-
-
 def extract_rv_and_value_vars(
     var: TensorVariable,
 ) -> Tuple[TensorVariable, TensorVariable]:
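The deleted change_rv_size presumably survives as a _resize_dist registration for plain RandomVariables elsewhere in this commit; that hunk is not shown here. A condensed, hypothetical sketch of such a registration, with the resize logic lifted from the deleted body above:

# Hypothetical registration; the name and placement are assumptions, the
# core logic is condensed from the deleted change_rv_size above.
import aesara.tensor as at

from aesara.tensor.random.op import RandomVariable

from pymc.distributions.shape_utils import _resize_dist


@_resize_dist.register(RandomVariable)
def resize_random_variable(op, rv, new_size, expand=False):
    rng, size, dtype, *dist_params = rv.owner.inputs
    if expand:
        # Prepend `new_size` to the existing size (shape minus support dims)
        shape = tuple(op._infer_shape(size, dist_params))
        new_size = tuple(new_size) + tuple(shape[: len(shape) - op.ndim_supp])
    # Dtype-aware conversion avoids picking up a needless `Cast` (see #4652)
    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
    return op.make_node(rng, new_size, dtype, *dist_params).outputs[-1]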

pymc/distributions/censored.py

Lines changed: 16 additions & 12 deletions

@@ -19,7 +19,12 @@
 from aesara.tensor.random.op import RandomVariable
 
 from pymc.distributions.distribution import SymbolicDistribution, _moment
-from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist
+from pymc.distributions.shape_utils import (
+    _ndim_supp_dist,
+    _resize_dist,
+    ndim_supp_dist,
+    resize_dist,
+)
 from pymc.util import check_dist_not_registered
 
 
@@ -90,22 +95,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
         rv_out.tag.upper = upper
 
         if size is not None:
-            rv_out = cls.change_size(rv_out, size)
+            rv_out = resize_dist(rv_out, size)
         if rngs is not None:
             rv_out = cls.change_rngs(rv_out, rngs)
 
         return rv_out
 
-    @classmethod
-    def change_size(cls, rv, new_size, expand=False):
-        dist_node = rv.tag.dist.owner
-        lower = rv.tag.lower
-        upper = rv.tag.upper
-        rng, old_size, dtype, *dist_params = dist_node.inputs
-        new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
-        new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
-        return cls.rv_op(new_dist, lower, upper)
-
     @classmethod
     def change_rngs(cls, rv, new_rngs):
         (new_rng,) = new_rngs
@@ -127,6 +122,15 @@ def ndim_supp_censored(op, dist):
     return 0
 
 
+@_resize_dist.register(Clip)
+def resize_censored(op, rv, new_size, expand=False):
+    dist = rv.tag.dist
+    lower = rv.tag.lower
+    upper = rv.tag.upper
+    new_dist = resize_dist(dist, new_size, expand=expand)
+    return Censored.rv_op(new_dist, lower, upper)
+
+
 @_moment.register(Clip)
 def moment_censored(op, rv, dist, lower, upper):
     moment = at.switch(
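Note that resize_censored also improves on the removed classmethod: it resizes the underlying dist through resize_dist itself, so nested symbolic distributions resize recursively. With the registration on Clip, a censored distribution goes through the same entry point as any other; a small usage sketch (the exact Censored parametrization is illustrative):

import pymc as pm

from pymc.distributions.shape_utils import resize_dist

# A scalar censored normal, resized to a batch of three draws;
# resize_dist dispatches to resize_censored via the Clip Op.
censored = pm.Censored.dist(pm.Normal.dist(), lower=-1.0, upper=1.0)
batched = resize_dist(censored, (3,))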

pymc/distributions/distribution.py

Lines changed: 6 additions & 6 deletions

@@ -33,7 +33,6 @@
 from aesara.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
 
-from pymc.aesaraf import change_rv_size
 from pymc.distributions.shape_utils import (
     Dims,
     Shape,
@@ -45,6 +44,7 @@
     convert_size,
     find_size,
     ndim_supp_dist,
+    resize_dist,
     resize_from_dims,
     resize_from_observed,
 )
@@ -270,7 +270,7 @@ def __new__(
 
         if resize_shape:
             # A batch size was specified through `dims`, or implied by `observed`.
-            rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
+            rv_out = resize_dist(dist=rv_out, new_size=resize_shape, expand=True)
 
         rv_out = model.register_rv(
             rv_out,
@@ -356,7 +356,7 @@ def dist(
         # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
         if shape is not None and Ellipsis in shape:
             replicate_shape = cast(StrongShape, shape[:-1])
-            rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True)
+            rv_out = resize_dist(dist=rv_out, new_size=replicate_shape, expand=True)
 
         rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
         rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
@@ -400,9 +400,6 @@ def __new__(
     cls.rv_op
         Returns a TensorVariable that represents the symbolic distribution
         parametrized by a default set of parameters and a size and rngs arguments
-    cls.change_size
-        Returns an equivalent symbolic distribution with a different size. This is
-        analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
     cls.graph_rvs
         Returns base RVs in a symbolic distribution.
 
@@ -413,6 +410,9 @@ def __new__(
         constant, for instance if the symbolic distribution can be defined based
        on an arbitrary base distribution. This is called by
        `pymc.distributions.shape_utils.ndim_supp_dist`
+    _resize_dist
+        Returns an equivalent symbolic distribution with a different size. This is
+        called by `pymc.distributions.shape_utils.resize_dist`.
 
     Parameters
     ----------
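Both rewritten call sites exercise the expand=True path: a batch shape implied by dims or observed, and replicate dimensions prepended via a shape ending in Ellipsis. An illustrative example of the Ellipsis case (the resulting shape follows from the diff's comment, assuming Ellipsis shapes behave as documented at this commit):

import pymc as pm

# shape=(4, ...) keeps the parameter-implied shape (2,) and prepends a
# replicate dimension, which internally calls
# resize_dist(dist=rv_out, new_size=(4,), expand=True).
x = pm.Normal.dist(mu=[0.0, 1.0], shape=(4, ...))
assert tuple(x.shape.eval()) == (4, 2)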

pymc/distributions/mixture.py

Lines changed: 29 additions & 23 deletions

@@ -25,13 +25,18 @@
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
-from pymc.aesaraf import change_rv_size
 from pymc.distributions import transforms
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
 from pymc.distributions.logprob import logcdf, logp
-from pymc.distributions.shape_utils import _ndim_supp_dist, ndim_supp_dist, to_tuple
+from pymc.distributions.shape_utils import (
+    _ndim_supp_dist,
+    _resize_dist,
+    ndim_supp_dist,
+    resize_dist,
+    to_tuple,
+)
 from pymc.distributions.transforms import _default_transform
 from pymc.util import check_dist_not_registered
 from pymc.vartypes import continuous_types, discrete_types
@@ -323,27 +328,7 @@ def _resize_components(cls, size, *components):
             mix_size = components[0].shape[mix_axis]
             size = tuple(size) + (mix_size,)
 
-        return [change_rv_size(component, size) for component in components]
-
-    @classmethod
-    def change_size(cls, rv, new_size, expand=False):
-        mix_indexes_rng, weights, *components = rv.owner.inputs
-        rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng]
-
-        if expand:
-            component = components[0]
-            # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
-            size_dims = component.ndim - ndim_supp_dist(component)
-            if len(components) == 1:
-                # If we have a single component, new size should ignore the mixture axis
-                # dimension, as that is not touched by `_resize_components`
-                size_dims -= 1
-            old_size = components[0].shape[:size_dims]
-            new_size = to_tuple(new_size) + tuple(old_size)
-
-        components = cls._resize_components(new_size, *components)
-
-        return cls.rv_op(weights, *components, rngs=rngs, size=None)
+        return [resize_dist(component, size) for component in components]
 
     @classmethod
     def graph_rvs(cls, rv):
@@ -361,6 +346,27 @@ def ndim_supp_marginal_mixture(op, rv):
     return ndim_supp_dist(components[0])
 
 
+@_resize_dist.register(MarginalMixtureRV)
+def resize_marginal_mixture(op, rv, new_size, expand=False):
+    mix_indexes_rng, weights, *components = rv.owner.inputs
+    rngs = [component.owner.inputs[0] for component in components] + [mix_indexes_rng]
+
+    if expand:
+        component = components[0]
+        # Old size is equal to `shape[:-ndim_supp]`, with care needed for `ndim_supp == 0`
+        size_dims = component.ndim - ndim_supp_dist(component)
+        if len(components) == 1:
+            # If we have a single component, new size should ignore the mixture axis
+            # dimension, as that is not touched by `_resize_components`
+            size_dims -= 1
+        old_size = components[0].shape[:size_dims]
+        new_size = to_tuple(new_size) + tuple(old_size)
+
+    components = Mixture._resize_components(new_size, *components)
+
+    return Mixture.rv_op(weights, *components, rngs=rngs, size=None)
+
+
 @_get_measurable_outputs.register(MarginalMixtureRV)
 def _get_measurable_outputs_MarginalMixtureRV(op, node):
     # This tells Aeppl that the second output is the measurable one
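resize_marginal_mixture reproduces the removed Mixture.change_size body verbatim, but as a module-level registration on MarginalMixtureRV rather than a classmethod, so resize_dist can reach it by dispatch; _resize_components in turn now resizes each component through resize_dist as well. An illustrative usage sketch (the Mixture parametrization here is an assumption):

import pymc as pm

from pymc.distributions.shape_utils import resize_dist

# A scalar two-component mixture, expanded to a batch of five;
# resize_dist dispatches to resize_marginal_mixture.
components = [pm.Normal.dist(-1.0), pm.Normal.dist(1.0)]
mix = pm.Mixture.dist(w=[0.3, 0.7], comp_dists=components)
bigger = resize_dist(mix, (5,), expand=True)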

pymc/distributions/multivariate.py

Lines changed: 4 additions & 3 deletions

@@ -41,7 +41,7 @@
 
 import pymc as pm
 
-from pymc.aesaraf import change_rv_size, floatX, intX
+from pymc.aesaraf import floatX, intX
 from pymc.distributions import transforms
 from pymc.distributions.continuous import (
     BoundedContinuous,
@@ -60,6 +60,7 @@
 from pymc.distributions.shape_utils import (
     broadcast_dist_samples_to,
     ndim_supp_dist,
+    resize_dist,
     rv_size_is_none,
     to_tuple,
 )
@@ -1199,10 +1200,10 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         # Since `eta` and `n` are forced to be scalars we don't need to worry about
         # implied batched dimensions for the time being.
         if ndim_supp_dist(sd_dist) == 0:
-            sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
+            sd_dist = resize_dist(sd_dist, to_tuple(size) + (n,))
         else:
             # The support shape must be `n` but we have no way of controlling it
-            sd_dist = change_rv_size(sd_dist, to_tuple(size))
+            sd_dist = resize_dist(sd_dist, to_tuple(size))
 
         # sd_dist is part of the generative graph, but should be completely ignored
        # by the logp graph, since the LKJ logp explicitly includes these terms.
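In this LKJCholeskyCov hunk, a user-supplied scalar sd_dist is resized so that each of the n diagonal entries gets its own standard-deviation draw. An illustrative sketch of what the ndim_supp_dist(sd_dist) == 0 branch computes, standing alone:

import pymc as pm

from pymc.distributions.shape_utils import resize_dist, to_tuple

# Mirrors the scalar branch above with n=3 and size=None;
# to_tuple(None) is (), so the resized sd_dist has shape (3,).
n, size = 3, None
sd_dist = pm.Exponential.dist(1.0)
sd_dist = resize_dist(sd_dist, to_tuple(size) + (n,))
assert tuple(sd_dist.shape.eval()) == (3,)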
