Skip to content

Commit 6f69a91

Browse files
Remove maybe_resize helper function and fix some type hints
1 parent 10cbfac commit 6f69a91

File tree

3 files changed

+32
-122
lines changed

3 files changed

+32
-122
lines changed

pymc/distributions/distribution.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from abc import ABCMeta
2121
from functools import singledispatch
22-
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
22+
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast
2323

2424
import aesara
2525
import numpy as np
@@ -45,7 +45,6 @@
4545
convert_shape,
4646
convert_size,
4747
find_size,
48-
maybe_resize,
4948
resize_from_dims,
5049
resize_from_observed,
5150
)
@@ -353,17 +352,11 @@ def dist(
353352
# Create the RV with a `size` right away.
354353
# This is not necessarily the final result.
355354
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
356-
rv_out = maybe_resize(
357-
rv_out,
358-
cls.rv_op,
359-
dist_params,
360-
ndim_expected,
361-
ndim_batch,
362-
ndim_supp,
363-
shape,
364-
size,
365-
**kwargs,
366-
)
355+
356+
# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
357+
if shape is not None and Ellipsis in shape:
358+
replicate_shape = cast(StrongShape, shape[:-1])
359+
rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)
367360

368361
rng = kwargs.pop("rng", None)
369362
if (
@@ -589,18 +582,11 @@ def dist(
589582
# Create the RV with a `size` right away.
590583
# This is not necessarily the final result.
591584
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
592-
graph = maybe_resize(
593-
graph,
594-
cls.rv_op,
595-
dist_params,
596-
ndim_expected,
597-
ndim_batch,
598-
ndim_supp,
599-
shape,
600-
size,
601-
change_rv_size_fn=cls.change_size,
602-
**kwargs,
603-
)
585+
586+
# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
587+
if shape is not None and Ellipsis in shape:
588+
replicate_shape = cast(StrongShape, shape[:-1])
589+
graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)
604590

605591
rngs = kwargs.pop("rngs", None)
606592
if rngs is not None:

pymc/distributions/shape_utils.py

Lines changed: 20 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,15 @@
1818
samples from probability distributions for stochastic nodes in PyMC.
1919
"""
2020

21-
import warnings
22-
23-
from functools import partial
24-
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
21+
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast
2522

2623
import numpy as np
2724

2825
from aesara.graph.basic import Constant, Variable
29-
from aesara.graph.op import Op
3026
from aesara.tensor.var import TensorVariable
3127
from typing_extensions import TypeAlias
3228

33-
from pymc.aesaraf import change_rv_size, pandas_to_array
34-
from pymc.exceptions import ShapeError, ShapeWarning
29+
from pymc.aesaraf import pandas_to_array
3530

3631
__all__ = [
3732
"to_tuple",
@@ -525,19 +520,22 @@ def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSi
525520
# We don't have a way to know the names of implied
526521
# dimensions, so they will be `None`.
527522
dims = (*dims[:-1], *[None] * ndim_implied)
523+
sdims = cast(StrongDims, dims)
528524

529-
ndim_resize = len(dims) - ndim_implied
525+
ndim_resize = len(sdims) - ndim_implied
530526

531527
# All resize dims must be known already (numerically or symbolically).
532-
unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
528+
unknowndim_resize_dims = set(sdims[:ndim_resize]) - set(model.dim_lengths)
533529
if unknowndim_resize_dims:
534530
raise KeyError(
535531
f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
536532
)
537533

538534
# The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
539-
resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
540-
return resize_shape, dims
535+
resize_shape: Tuple[Variable, ...] = tuple(
536+
model.dim_lengths[dname] for dname in sdims[:ndim_resize]
537+
)
538+
return resize_shape, sdims
541539

542540

543541
def resize_from_observed(
@@ -566,26 +564,30 @@ def resize_from_observed(
566564
return resize_shape, observed
567565

568566

569-
def find_size(shape=None, size=None, ndim_supp=None):
567+
def find_size(
568+
shape: Optional[WeakShape],
569+
size: Optional[StrongSize],
570+
ndim_supp: int,
571+
) -> Tuple[Optional[StrongSize], Optional[int], Optional[int], int]:
570572
"""Determines the size keyword argument for creating a Distribution.
571573
572574
Parameters
573575
----------
574-
shape : tuple
576+
shape
575577
A tuple specifying the final shape of a distribution
576-
size : tuple
578+
size
577579
A tuple specifying the size of a distribution
578580
ndim_supp : int
579581
The support dimension of the distribution.
580-
0 if a univariate distribution, 1 if a multivariate distribution.
582+
0 if a univariate distribution, 1 or higher for multivariate distributions.
581583
582584
Returns
583585
-------
584-
create_size : int
586+
create_size : int, optional
585587
The size argument to be passed to the distribution
586-
ndim_expected : int
588+
ndim_expected : int, optional
587589
Number of dimensions expected after distribution was created
588-
ndim_batch : int
590+
ndim_batch : int, optional
589591
Number of batch dimensions
590592
ndim_supp : int
591593
Number of support dimensions
@@ -614,84 +616,6 @@ def find_size(shape=None, size=None, ndim_supp=None):
614616
return create_size, ndim_expected, ndim_batch, ndim_supp
615617

616618

617-
def maybe_resize(
618-
rv_out: TensorVariable,
619-
rv_op: Op,
620-
dist_params,
621-
ndim_expected: int,
622-
ndim_batch,
623-
ndim_supp,
624-
shape,
625-
size,
626-
*,
627-
change_rv_size_fn=partial(change_rv_size, expand=True),
628-
**kwargs,
629-
):
630-
"""Resize a distribution if necessary.
631-
632-
Parameters
633-
----------
634-
rv_out : RandomVariable
635-
The RandomVariable to be resized if necessary
636-
rv_op : RandomVariable.__class__
637-
The RandomVariable class to recreate it
638-
dist_params : dict
639-
Input parameters to recreate the RandomVariable
640-
ndim_expected : int
641-
Number of dimensions expected after distribution was created
642-
ndim_batch : int
643-
Number of batch dimensions
644-
ndim_supp : int
645-
The support dimension of the distribution.
646-
0 if a univariate distribution, 1 if a multivariate distribution.
647-
shape : tuple
648-
A tuple specifying the final shape of a distribution
649-
size : tuple
650-
A tuple specifying the size of a distribution
651-
change_rv_size_fn: callable
652-
A function that returns an equivalent RV with a different size
653-
654-
Returns
655-
-------
656-
rv_out : int
657-
The size argument to be passed to the distribution
658-
"""
659-
ndim_actual = rv_out.ndim
660-
ndims_unexpected = ndim_actual != ndim_expected
661-
662-
if shape is not None and ndims_unexpected:
663-
if Ellipsis in shape:
664-
# Resize and we're done!
665-
rv_out = change_rv_size_fn(rv_var=rv_out, new_size=shape[:-1])
666-
else:
667-
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
668-
# Recreate the RV without passing `size` to created it with just the implied dimensions.
669-
rv_out = rv_op(*dist_params, size=None, **kwargs)
670-
671-
# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
672-
if rv_out.ndim < ndim_expected:
673-
expand_shape = shape[: ndim_expected - rv_out.ndim]
674-
rv_out = change_rv_size_fn(rv_var=rv_out, new_size=expand_shape)
675-
if not rv_out.ndim == ndim_expected:
676-
raise ShapeError(
677-
f"Failed to create the RV with the expected dimensionality. "
678-
f"This indicates a severe problem. Please open an issue.",
679-
actual=ndim_actual,
680-
expected=ndim_batch + ndim_supp,
681-
)
682-
683-
# Warn about the edge cases where the RV Op creates more dimensions than
684-
# it should based on `size` and `RVOp.ndim_supp`.
685-
if size is not None and ndims_unexpected:
686-
warnings.warn(
687-
f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional."
688-
' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.',
689-
ShapeWarning,
690-
)
691-
692-
return rv_out
693-
694-
695619
def rv_size_is_none(size: Variable) -> bool:
696620
"""Check wether an rv size is None (ie., at.Constant([]))"""
697621
return isinstance(size, Constant) and size.data.size == 0

pymc/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616

17-
from typing import Dict, List, Tuple, Union
17+
from typing import Dict, List, Tuple, Union, cast
1818

1919
import arviz
2020
import cloudpickle

0 commit comments

Comments
 (0)