|
19 | 19 |
|
20 | 20 | from abc import ABCMeta
|
21 | 21 | from functools import singledispatch
|
22 |
| -from typing import Callable, Iterable, Optional, Sequence |
| 22 | +from typing import Callable, Iterable, Optional, Sequence, Tuple, Union |
23 | 23 |
|
24 | 24 | import aesara
|
| 25 | +import numpy as np |
25 | 26 |
|
26 | 27 | from aeppl.logprob import _logcdf, _logprob
|
27 | 28 | from aesara import tensor as at
|
| 29 | +from aesara.graph.basic import Variable |
28 | 30 | from aesara.tensor.basic import as_tensor_variable
|
29 | 31 | from aesara.tensor.elemwise import Elemwise
|
30 | 32 | from aesara.tensor.random.op import RandomVariable
|
|
36 | 38 | Dims,
|
37 | 39 | Shape,
|
38 | 40 | Size,
|
| 41 | + StrongShape, |
| 42 | + WeakDims, |
39 | 43 | convert_dims,
|
40 | 44 | convert_shape,
|
41 | 45 | convert_size,
|
@@ -133,6 +137,37 @@ def fn(*args, **kwargs):
|
133 | 137 | return fn
|
134 | 138 |
|
135 | 139 |
|
| 140 | +def _make_rv_and_resize_shape( |
| 141 | + *, |
| 142 | + cls, |
| 143 | + dims: Optional[Dims], |
| 144 | + model, |
| 145 | + observed, |
| 146 | + args, |
| 147 | + **kwargs, |
| 148 | +) -> Tuple[Variable, Optional[WeakDims], Optional[Union[np.ndarray, Variable]], StrongShape]: |
| 149 | + """Creates the RV and processes dims or observed to determine a resize shape.""" |
| 150 | + # Create the RV without dims information, because that's not something tracked at the Aesara level. |
| 151 | + # If necessary we'll later replicate to a different size implied by already known dims. |
| 152 | + rv_out = cls.dist(*args, **kwargs) |
| 153 | + ndim_actual = rv_out.ndim |
| 154 | + resize_shape = None |
| 155 | + |
| 156 | + # # `dims` are only available with this API, because `.dist()` can be used |
| 157 | + # # without a modelcontext and dims are not tracked at the Aesara level. |
| 158 | + dims = convert_dims(dims) |
| 159 | + dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None |
| 160 | + if dims is not None: |
| 161 | + if dims_can_resize: |
| 162 | + resize_shape, dims = resize_from_dims(dims, ndim_actual, model) |
| 163 | + elif Ellipsis in dims: |
| 164 | + # Replace ... with None entries to match the actual dimensionality. |
| 165 | + dims = (*dims[:-1], *[None] * ndim_actual)[:ndim_actual] |
| 166 | + elif observed is not None: |
| 167 | + resize_shape, observed = resize_from_observed(observed, ndim_actual) |
| 168 | + return rv_out, dims, observed, resize_shape |
| 169 | + |
| 170 | + |
136 | 171 | class Distribution(metaclass=DistributionMeta):
|
137 | 172 | """Statistical distribution"""
|
138 | 173 |
|
@@ -213,28 +248,11 @@ def __new__(
|
213 | 248 | if rng is None:
|
214 | 249 | rng = model.next_rng()
|
215 | 250 |
|
216 |
| - if dims is not None and "shape" in kwargs: |
217 |
| - raise ValueError( |
218 |
| - f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!" |
219 |
| - ) |
220 |
| - if dims is not None and "size" in kwargs: |
221 |
| - raise ValueError( |
222 |
| - f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!" |
223 |
| - ) |
224 |
| - dims = convert_dims(dims) |
225 |
| - |
226 |
| - # Create the RV without dims information, because that's not something tracked at the Aesara level. |
227 |
| - # If necessary we'll later replicate to a different size implied by already known dims. |
228 |
| - rv_out = cls.dist(*args, rng=rng, **kwargs) |
229 |
| - ndim_actual = rv_out.ndim |
230 |
| - resize_shape = None |
231 |
| - |
232 |
| - # `dims` are only available with this API, because `.dist()` can be used |
233 |
| - # without a modelcontext and dims are not tracked at the Aesara level. |
234 |
| - if dims is not None: |
235 |
| - ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model) |
236 |
| - elif observed is not None: |
237 |
| - ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual) |
| 251 | + # Create the RV and process dims and observed to determine |
| 252 | + # a shape by which the created RV may need to be resized. |
| 253 | + rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 254 | + cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs |
| 255 | + ) |
238 | 256 |
|
239 | 257 | if resize_shape:
|
240 | 258 | # A batch size was specified through `dims`, or implied by `observed`.
|
@@ -456,35 +474,18 @@ def __new__(
|
456 | 474 | if not isinstance(name, string_types):
|
457 | 475 | raise TypeError(f"Name needs to be a string but got: {name}")
|
458 | 476 |
|
459 |
| - if dims is not None and "shape" in kwargs: |
460 |
| - raise ValueError( |
461 |
| - f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!" |
462 |
| - ) |
463 |
| - if dims is not None and "size" in kwargs: |
464 |
| - raise ValueError( |
465 |
| - f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!" |
466 |
| - ) |
467 |
| - dims = convert_dims(dims) |
468 |
| - |
469 | 477 | if rngs is None:
|
470 | 478 | # Create a temporary rv to obtain number of rngs needed
|
471 | 479 | temp_graph = cls.dist(*args, rngs=None, **kwargs)
|
472 | 480 | rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
|
473 | 481 | elif not isinstance(rngs, (list, tuple)):
|
474 | 482 | rngs = [rngs]
|
475 | 483 |
|
476 |
| - # Create the RV without dims information, because that's not something tracked at the Aesara level. |
477 |
| - # If necessary we'll later replicate to a different size implied by already known dims. |
478 |
| - rv_out = cls.dist(*args, rngs=rngs, **kwargs) |
479 |
| - ndim_actual = rv_out.ndim |
480 |
| - resize_shape = None |
481 |
| - |
482 |
| - # # `dims` are only available with this API, because `.dist()` can be used |
483 |
| - # # without a modelcontext and dims are not tracked at the Aesara level. |
484 |
| - if dims is not None: |
485 |
| - ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model) |
486 |
| - elif observed is not None: |
487 |
| - ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual) |
| 484 | + # Create the RV and process dims and observed to determine |
| 485 | + # a shape by which the created RV may need to be resized. |
| 486 | + rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 487 | + cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs |
| 488 | + ) |
488 | 489 |
|
489 | 490 | if resize_shape:
|
490 | 491 | # A batch size was specified through `dims`, or implied by `observed`.
|
|
0 commit comments