|
18 | 18 | samples from probability distributions for stochastic nodes in PyMC.
|
19 | 19 | """
|
20 | 20 |
|
21 |
| -import warnings |
22 |
| - |
23 |
| -from functools import partial |
24 |
| -from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union |
| 21 | +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast |
25 | 22 |
|
26 | 23 | import numpy as np
|
27 | 24 |
|
28 | 25 | from aesara.graph.basic import Constant, Variable
|
29 |
| -from aesara.graph.op import Op |
30 | 26 | from aesara.tensor.var import TensorVariable
|
31 | 27 | from typing_extensions import TypeAlias
|
32 | 28 |
|
33 |
| -from pymc.aesaraf import change_rv_size, pandas_to_array |
34 |
| -from pymc.exceptions import ShapeError, ShapeWarning |
| 29 | +from pymc.aesaraf import pandas_to_array |
35 | 30 |
|
36 | 31 | __all__ = [
|
37 | 32 | "to_tuple",
|
@@ -525,19 +520,22 @@ def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSi
|
525 | 520 | # We don't have a way to know the names of implied
|
526 | 521 | # dimensions, so they will be `None`.
|
527 | 522 | dims = (*dims[:-1], *[None] * ndim_implied)
|
| 523 | + sdims = cast(StrongDims, dims) |
528 | 524 |
|
529 |
| - ndim_resize = len(dims) - ndim_implied |
| 525 | + ndim_resize = len(sdims) - ndim_implied |
530 | 526 |
|
531 | 527 | # All resize dims must be known already (numerically or symbolically).
|
532 |
| - unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths) |
| 528 | + unknowndim_resize_dims = set(sdims[:ndim_resize]) - set(model.dim_lengths) |
533 | 529 | if unknowndim_resize_dims:
|
534 | 530 | raise KeyError(
|
535 | 531 | f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
|
536 | 532 | )
|
537 | 533 |
|
538 | 534 | # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
|
539 |
| - resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize]) |
540 |
| - return resize_shape, dims |
| 535 | + resize_shape: Tuple[Variable, ...] = tuple( |
| 536 | + model.dim_lengths[dname] for dname in sdims[:ndim_resize] |
| 537 | + ) |
| 538 | + return resize_shape, sdims |
541 | 539 |
|
542 | 540 |
|
543 | 541 | def resize_from_observed(
|
@@ -566,26 +564,30 @@ def resize_from_observed(
|
566 | 564 | return resize_shape, observed
|
567 | 565 |
|
568 | 566 |
|
569 |
| -def find_size(shape=None, size=None, ndim_supp=None): |
| 567 | +def find_size( |
| 568 | + shape: Optional[WeakShape], |
| 569 | + size: Optional[StrongSize], |
| 570 | + ndim_supp: int, |
| 571 | +) -> Tuple[Optional[StrongSize], Optional[int], Optional[int], int]: |
570 | 572 | """Determines the size keyword argument for creating a Distribution.
|
571 | 573 |
|
572 | 574 | Parameters
|
573 | 575 | ----------
|
574 |
| - shape : tuple |
| 576 | + shape |
575 | 577 | A tuple specifying the final shape of a distribution
|
576 |
| - size : tuple |
| 578 | + size |
577 | 579 | A tuple specifying the size of a distribution
|
578 | 580 | ndim_supp : int
|
579 | 581 | The support dimension of the distribution.
|
580 |
| - 0 if a univariate distribution, 1 if a multivariate distribution. |
| 582 | + 0 if a univariate distribution, 1 or higher for multivariate distributions. |
581 | 583 |
|
582 | 584 | Returns
|
583 | 585 | -------
|
584 |
| - create_size : int |
| 586 | + create_size : int, optional |
585 | 587 | The size argument to be passed to the distribution
|
586 |
| - ndim_expected : int |
| 588 | + ndim_expected : int, optional |
587 | 589 | Number of dimensions expected after distribution was created
|
588 |
| - ndim_batch : int |
| 590 | + ndim_batch : int, optional |
589 | 591 | Number of batch dimensions
|
590 | 592 | ndim_supp : int
|
591 | 593 | Number of support dimensions
|
@@ -614,84 +616,6 @@ def find_size(shape=None, size=None, ndim_supp=None):
|
614 | 616 | return create_size, ndim_expected, ndim_batch, ndim_supp
|
615 | 617 |
|
616 | 618 |
|
617 |
| -def maybe_resize( |
618 |
| - rv_out: TensorVariable, |
619 |
| - rv_op: Op, |
620 |
| - dist_params, |
621 |
| - ndim_expected: int, |
622 |
| - ndim_batch, |
623 |
| - ndim_supp, |
624 |
| - shape, |
625 |
| - size, |
626 |
| - *, |
627 |
| - change_rv_size_fn=partial(change_rv_size, expand=True), |
628 |
| - **kwargs, |
629 |
| -): |
630 |
| - """Resize a distribution if necessary. |
631 |
| -
|
632 |
| - Parameters |
633 |
| - ---------- |
634 |
| - rv_out : RandomVariable |
635 |
| - The RandomVariable to be resized if necessary |
636 |
| - rv_op : RandomVariable.__class__ |
637 |
| - The RandomVariable class to recreate it |
638 |
| - dist_params : dict |
639 |
| - Input parameters to recreate the RandomVariable |
640 |
| - ndim_expected : int |
641 |
| - Number of dimensions expected after distribution was created |
642 |
| - ndim_batch : int |
643 |
| - Number of batch dimensions |
644 |
| - ndim_supp : int |
645 |
| - The support dimension of the distribution. |
646 |
| - 0 if a univariate distribution, 1 if a multivariate distribution. |
647 |
| - shape : tuple |
648 |
| - A tuple specifying the final shape of a distribution |
649 |
| - size : tuple |
650 |
| - A tuple specifying the size of a distribution |
651 |
| - change_rv_size_fn: callable |
652 |
| - A function that returns an equivalent RV with a different size |
653 |
| -
|
654 |
| - Returns |
655 |
| - ------- |
656 |
| - rv_out : int |
657 |
| - The size argument to be passed to the distribution |
658 |
| - """ |
659 |
| - ndim_actual = rv_out.ndim |
660 |
| - ndims_unexpected = ndim_actual != ndim_expected |
661 |
| - |
662 |
| - if shape is not None and ndims_unexpected: |
663 |
| - if Ellipsis in shape: |
664 |
| - # Resize and we're done! |
665 |
| - rv_out = change_rv_size_fn(rv_var=rv_out, new_size=shape[:-1]) |
666 |
| - else: |
667 |
| - # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)). |
668 |
| - # Recreate the RV without passing `size` to created it with just the implied dimensions. |
669 |
| - rv_out = rv_op(*dist_params, size=None, **kwargs) |
670 |
| - |
671 |
| - # Now resize by any remaining "extra" dimensions that were not implied from support and parameters |
672 |
| - if rv_out.ndim < ndim_expected: |
673 |
| - expand_shape = shape[: ndim_expected - rv_out.ndim] |
674 |
| - rv_out = change_rv_size_fn(rv_var=rv_out, new_size=expand_shape) |
675 |
| - if not rv_out.ndim == ndim_expected: |
676 |
| - raise ShapeError( |
677 |
| - f"Failed to create the RV with the expected dimensionality. " |
678 |
| - f"This indicates a severe problem. Please open an issue.", |
679 |
| - actual=ndim_actual, |
680 |
| - expected=ndim_batch + ndim_supp, |
681 |
| - ) |
682 |
| - |
683 |
| - # Warn about the edge cases where the RV Op creates more dimensions than |
684 |
| - # it should based on `size` and `RVOp.ndim_supp`. |
685 |
| - if size is not None and ndims_unexpected: |
686 |
| - warnings.warn( |
687 |
| - f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional." |
688 |
| - ' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.', |
689 |
| - ShapeWarning, |
690 |
| - ) |
691 |
| - |
692 |
| - return rv_out |
693 |
| - |
694 |
| - |
695 | 619 | def rv_size_is_none(size: Variable) -> bool:
|
696 | 620 | """Check wether an rv size is None (ie., at.Constant([]))"""
|
697 | 621 | return isinstance(size, Constant) and size.data.size == 0
|
0 commit comments