Skip to content

Commit 9155922

Browse files
Re-enable passing dims alongside shape or size (#5325)
* Remove unused return value from helper functions * Restore support for passing `dims` alongside `shape` or `size` * Extract RV creation and `resize_shape` determination code Closes #4656
1 parent 6570e95 commit 9155922

File tree

3 files changed

+96
-75
lines changed

3 files changed

+96
-75
lines changed

pymc/distributions/distribution.py

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919

2020
from abc import ABCMeta
2121
from functools import singledispatch
22-
from typing import Callable, Iterable, Optional, Sequence
22+
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
2323

2424
import aesara
25+
import numpy as np
2526

2627
from aeppl.logprob import _logcdf, _logprob
2728
from aesara import tensor as at
29+
from aesara.graph.basic import Variable
2830
from aesara.tensor.basic import as_tensor_variable
2931
from aesara.tensor.elemwise import Elemwise
3032
from aesara.tensor.random.op import RandomVariable
@@ -36,6 +38,8 @@
3638
Dims,
3739
Shape,
3840
Size,
41+
StrongShape,
42+
WeakDims,
3943
convert_dims,
4044
convert_shape,
4145
convert_size,
@@ -133,6 +137,37 @@ def fn(*args, **kwargs):
133137
return fn
134138

135139

140+
def _make_rv_and_resize_shape(
141+
*,
142+
cls,
143+
dims: Optional[Dims],
144+
model,
145+
observed,
146+
args,
147+
**kwargs,
148+
) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]:
149+
"""Creates the RV and processes dims or observed to determine a resize shape."""
150+
# Create the RV without dims information, because that's not something tracked at the Aesara level.
151+
# If necessary we'll later replicate to a different size implied by already known dims.
152+
rv_out = cls.dist(*args, **kwargs)
153+
ndim_actual = rv_out.ndim
154+
resize_shape = None
155+
156+
# # `dims` are only available with this API, because `.dist()` can be used
157+
# # without a modelcontext and dims are not tracked at the Aesara level.
158+
dims = convert_dims(dims)
159+
dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None
160+
if dims is not None:
161+
if dims_can_resize:
162+
resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
163+
elif Ellipsis in dims:
164+
# Replace ... with None entries to match the actual dimensionality.
165+
dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual]
166+
elif observed is not None:
167+
resize_shape, observed = resize_from_observed(observed, ndim_actual)
168+
return rv_out, dims, observed, resize_shape
169+
170+
136171
class Distribution(metaclass=DistributionMeta):
137172
"""Statistical distribution"""
138173

@@ -213,28 +248,11 @@ def __new__(
213248
if rng is None:
214249
rng = model.next_rng()
215250

216-
if dims is not None and "shape" in kwargs:
217-
raise ValueError(
218-
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
219-
)
220-
if dims is not None and "size" in kwargs:
221-
raise ValueError(
222-
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
223-
)
224-
dims = convert_dims(dims)
225-
226-
# Create the RV without dims information, because that's not something tracked at the Aesara level.
227-
# If necessary we'll later replicate to a different size implied by already known dims.
228-
rv_out = cls.dist(*args, rng=rng, **kwargs)
229-
ndim_actual = rv_out.ndim
230-
resize_shape = None
231-
232-
# `dims` are only available with this API, because `.dist()` can be used
233-
# without a modelcontext and dims are not tracked at the Aesara level.
234-
if dims is not None:
235-
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
236-
elif observed is not None:
237-
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
251+
# Create the RV and process dims and observed to determine
252+
# a shape by which the created RV may need to be resized.
253+
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
254+
cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
255+
)
238256

239257
if resize_shape:
240258
# A batch size was specified through `dims`, or implied by `observed`.
@@ -456,35 +474,18 @@ def __new__(
456474
if not isinstance(name, string_types):
457475
raise TypeError(f"Name needs to be a string but got: {name}")
458476

459-
if dims is not None and "shape" in kwargs:
460-
raise ValueError(
461-
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
462-
)
463-
if dims is not None and "size" in kwargs:
464-
raise ValueError(
465-
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
466-
)
467-
dims = convert_dims(dims)
468-
469477
if rngs is None:
470478
# Create a temporary rv to obtain number of rngs needed
471479
temp_graph = cls.dist(*args, rngs=None, **kwargs)
472480
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
473481
elif not isinstance(rngs, (list, tuple)):
474482
rngs = [rngs]
475483

476-
# Create the RV without dims information, because that's not something tracked at the Aesara level.
477-
# If necessary we'll later replicate to a different size implied by already known dims.
478-
rv_out = cls.dist(*args, rngs=rngs, **kwargs)
479-
ndim_actual = rv_out.ndim
480-
resize_shape = None
481-
482-
# # `dims` are only available with this API, because `.dist()` can be used
483-
# # without a modelcontext and dims are not tracked at the Aesara level.
484-
if dims is not None:
485-
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
486-
elif observed is not None:
487-
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)
484+
# Create the RV and process dims and observed to determine
485+
# a shape by which the created RV may need to be resized.
486+
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
487+
cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs
488+
)
488489

489490
if resize_shape:
490491
# A batch size was specified through `dims`, or implied by `observed`.

pymc/distributions/shape_utils.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def broadcast_dist_samples_to(to_shape, samples, size=None):
427427
StrongSize = Union[TensorVariable, Tuple[Union[int, TensorVariable], ...]]
428428

429429

430-
def convert_dims(dims: Dims) -> Optional[WeakDims]:
430+
def convert_dims(dims: Optional[Dims]) -> Optional[WeakDims]:
431431
"""Process a user-provided dims variable into None or a valid dims tuple."""
432432
if dims is None:
433433
return None
@@ -487,9 +487,7 @@ def convert_size(size: Size) -> Optional[StrongSize]:
487487
return size
488488

489489

490-
def resize_from_dims(
491-
dims: WeakDims, ndim_implied: int, model
492-
) -> Tuple[int, StrongSize, StrongDims]:
490+
def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSize, StrongDims]:
493491
"""Determines a potential resize shape from a `dims` tuple.
494492
495493
Parameters
@@ -503,10 +501,10 @@ def resize_from_dims(
503501
504502
Returns
505503
-------
506-
ndim_resize : int
507-
Number of dimensions that should be added through resizing.
508504
resize_shape : array-like
509-
The shape of the new dimensions.
505+
Shape of new dimensions that should be prepended.
506+
dims : tuple of (str or None)
507+
Names or None for all dimensions after resizing.
510508
"""
511509
if Ellipsis in dims:
512510
# Auto-complete the dims tuple to the full length.
@@ -525,12 +523,12 @@ def resize_from_dims(
525523

526524
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
527525
resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
528-
return ndim_resize, resize_shape, dims
526+
return resize_shape, dims
529527

530528

531529
def resize_from_observed(
532530
observed, ndim_implied: int
533-
) -> Tuple[int, StrongSize, Union[np.ndarray, Variable]]:
531+
) -> Tuple[StrongSize, Union[np.ndarray, Variable]]:
534532
"""Determines a potential resize shape from observations.
535533
536534
Parameters
@@ -542,18 +540,16 @@ def resize_from_observed(
542540
543541
Returns
544542
-------
545-
ndim_resize : int
546-
Number of dimensions that should be added through resizing.
547543
resize_shape : array-like
548-
The shape of the new dimensions.
544+
Shape of new dimensions that should be prepended.
549545
observed : scalar, array-like
550546
Observations as numpy array or `Variable`.
551547
"""
552548
if not hasattr(observed, "shape"):
553549
observed = pandas_to_array(observed)
554550
ndim_resize = observed.ndim - ndim_implied
555551
resize_shape = tuple(observed.shape[d] for d in range(ndim_resize))
556-
return ndim_resize, resize_shape, observed
552+
return resize_shape, observed
557553

558554

559555
def find_size(shape=None, size=None, ndim_supp=None):

pymc/tests/test_shape_handling.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,47 @@ def test_param_and_batch_shape_combos(
293293
else:
294294
raise NotImplementedError("Invalid test case parametrization.")
295295

296+
@pytest.mark.parametrize("ellipsis_in", ["none", "shape", "dims", "both"])
297+
def test_simultaneous_shape_and_dims(self, ellipsis_in):
298+
with pm.Model() as pmodel:
299+
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
300+
301+
if ellipsis_in == "none":
302+
# The shape and dims tuples correspond to each other.
303+
# Note: No checks are performed that implied shape (x), shape and dims actually match.
304+
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", "ddata"))
305+
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
306+
elif ellipsis_in == "shape":
307+
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", "ddata"))
308+
assert pmodel.RV_dims["y"] == ("dshape", "ddata")
309+
elif ellipsis_in == "dims":
310+
y = pm.Normal("y", mu=x, shape=(2, 3), dims=("dshape", ...))
311+
assert pmodel.RV_dims["y"] == ("dshape", None)
312+
elif ellipsis_in == "both":
313+
y = pm.Normal("y", mu=x, shape=(2, ...), dims=("dshape", ...))
314+
assert pmodel.RV_dims["y"] == ("dshape", None)
315+
316+
assert "dshape" in pmodel.dim_lengths
317+
assert y.eval().shape == (2, 3)
318+
319+
@pytest.mark.parametrize("with_dims_ellipsis", [False, True])
320+
def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
321+
with pm.Model() as pmodel:
322+
x = pm.ConstantData("x", [1, 2, 3], dims="ddata")
323+
assert "ddata" in pmodel.dim_lengths
324+
325+
# Size does not include support dims, so this test must use a dist with support dims.
326+
kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4))
327+
if with_dims_ellipsis:
328+
y = pm.MvNormal(**kwargs, dims=("dsize", ...))
329+
assert pmodel.RV_dims["y"] == ("dsize", None, None)
330+
else:
331+
y = pm.MvNormal(**kwargs, dims=("dsize", "ddata", "dsupport"))
332+
assert pmodel.RV_dims["y"] == ("dsize", "ddata", "dsupport")
333+
334+
assert "dsize" in pmodel.dim_lengths
335+
assert y.eval().shape == (2, 3, 4)
336+
296337
def test_define_dims_on_the_fly(self):
297338
with pm.Model() as pmodel:
298339
agedata = aesara.shared(np.array([10, 20, 30]))
@@ -312,17 +353,6 @@ def test_define_dims_on_the_fly(self):
312353
# The change should propagate all the way through
313354
assert effect.eval().shape == (4,)
314355

315-
@pytest.mark.xfail(reason="Simultaneous use of size and dims is not implemented")
316-
def test_data_defined_size_dimension_can_register_dimname(self):
317-
with pm.Model() as pmodel:
318-
x = pm.ConstantData("x", [[1, 2, 3, 4]], dims=("first", "second"))
319-
assert "first" in pmodel.dim_lengths
320-
assert "second" in pmodel.dim_lengths
321-
# two dimensions are implied; a "third" dimension is created
322-
y = pm.Normal("y", mu=x, size=2, dims=("third", "first", "second"))
323-
assert "third" in pmodel.dim_lengths
324-
assert y.eval().shape() == (2, 1, 4)
325-
326356
def test_can_resize_data_defined_size(self):
327357
with pm.Model() as pmodel:
328358
x = pm.MutableData("x", [[1, 2, 3, 4]], dims=("first", "second"))
@@ -447,9 +477,3 @@ def test_lazy_flavors(self):
447477
def test_invalid_flavors(self):
448478
with pytest.raises(ValueError, match="Passing both"):
449479
pm.Normal.dist(0, 1, shape=(3,), size=(3,))
450-
451-
with pm.Model():
452-
with pytest.raises(ValueError, match="Passing both"):
453-
pm.Normal("n", shape=(2,), dims=("town",))
454-
with pytest.raises(ValueError, match="Passing both"):
455-
pm.Normal("n", dims=("town",), size=(2,))

0 commit comments

Comments
 (0)