diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index 46b3d2d826..a104331b5c 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index 6b98167eb3..488cfa9854 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index dd700dba6a..a0da6f8a96 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index 9aa684412d..7e2656a848 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 8df3a4759e..f0cb8b4c4d 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index e8d50dc2c5..a5f7ddd492 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools - cloudpickle diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 855187c643..0716a37b0c 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -5,7 +5,7 @@ channels: dependencies: # base dependencies (see install guide for Windows) - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index e23664e815..b61241293c 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -5,7 +5,7 @@ channels: dependencies: # base dependencies (see install guide for Windows) - aeppl=0.0.18 -- aesara>=2.2.6 +- aesara=2.3.2 - arviz>=0.11.2 - cachetools - cloudpickle diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 8800d705da..06dcbec5e4 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -717,16 +717,16 @@ def dist( def get_moment(rv, size, mu, sigma, lower, upper): mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper) moment = at.switch( - at.isinf(lower), + at.eq(lower, -np.inf), at.switch( - at.isinf(upper), + at.eq(upper, np.inf), # lower = -inf, upper = inf mu, # lower = -inf, upper = x upper - 1, ), at.switch( - at.isinf(upper), + at.eq(upper, np.inf), # lower = x, upper = inf lower + 1, # lower = x, upper = x diff --git a/pymc/initial_point.py b/pymc/initial_point.py index a53f9af33e..c73aa926e7 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import warnings from typing import Callable, Dict, List, Optional, Sequence, Set, Union @@ -131,7 +132,7 @@ def make_initial_point_fn( model, overrides: Optional[StartDict] = None, jitter_rvs: Optional[Set[TensorVariable]] = None, - default_strategy: str = "prior", + default_strategy: str = "moment", return_transformed: bool = True, ) -> Callable: """Create seeded function that computes initial values for all free model variables. @@ -226,7 +227,7 @@ def make_initial_point_expression( rvs_to_values: Dict[TensorVariable, TensorVariable], initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], jitter_rvs: Set[TensorVariable] = None, - default_strategy: str = "prior", + default_strategy: str = "moment", return_transformed: bool = False, ) -> List[TensorVariable]: """Creates the tensor variables that need to be evaluated to obtain an initial point. @@ -269,7 +270,19 @@ def make_initial_point_expression( if isinstance(strategy, str): if strategy == "moment": - value = get_moment(variable) + try: + value = get_moment(variable) + except NotImplementedError: + warnings.warn( + f"Moment not defined for variable {variable} of type " + f"{variable.owner.op.__class__.__name__}, defaulting to " + f"a draw from the prior. This can lead to difficulties " + f"during tuning. You can manually define an initval or " + f"implement a get_moment dispatched function for this " + f"distribution.", + UserWarning, + ) + value = variable elif strategy == "prior": value = variable else: diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index e9537e6bf1..2188a931c4 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -8,7 +8,6 @@ import pymc as pm -from pymc import Simulator from pymc.distributions import ( AsymmetricLaplace, Bernoulli, @@ -50,6 +49,7 @@ Poisson, PolyaGamma, Rice, + Simulator, SkewNormal, StudentT, Triangular, @@ -62,13 +62,74 @@ ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, ) -from pymc.distributions.distribution import get_moment +from pymc.distributions.distribution import _get_moment, get_moment +from pymc.distributions.logprob import logpt from pymc.distributions.multivariate import MvNormal from pymc.distributions.shape_utils import rv_size_is_none, to_tuple from pymc.initial_point import make_initial_point_fn from pymc.model import Model +def test_all_distributions_have_moments(): + import pymc.distributions as dist_module + + from pymc.distributions.distribution import DistributionMeta + + dists = (getattr(dist_module, dist) for dist in dist_module.__all__) + dists = (dist for dist in dists if isinstance(dist, DistributionMeta)) + missing_moments = { + dist for dist in dists if type(getattr(dist, "rv_op", None)) not in _get_moment.registry + } + + # Ignore super classes + missing_moments -= { + dist_module.Distribution, + dist_module.Discrete, + dist_module.Continuous, + dist_module.NoDistribution, + dist_module.DensityDist, + dist_module.simulator.Simulator, + } + + # Distributions that have not been refactored for V4 yet + not_implemented = { + dist_module.multivariate.LKJCorr, + dist_module.mixture.Mixture, + dist_module.mixture.MixtureSameFamily, + dist_module.mixture.NormalMixture, + dist_module.timeseries.AR, + dist_module.timeseries.AR1, + dist_module.timeseries.GARCH11, + dist_module.timeseries.GaussianRandomWalk, + dist_module.timeseries.MvGaussianRandomWalk, + dist_module.timeseries.MvStudentTRandomWalk, + } + + # Distributions that have been refactored but don't yet have moments + not_implemented |= { + dist_module.discrete.DiscreteWeibull, + dist_module.multivariate.CAR, + dist_module.multivariate.DirichletMultinomial, + dist_module.multivariate.KroneckerNormal, + dist_module.multivariate.Wishart, + } + + unexpected_implemented = not_implemented - missing_moments + if unexpected_implemented: + raise Exception( + f"Distributions {unexpected_implemented} have a `get_moment` implemented. " + "This test must be updated to expect this." + ) + + unexpected_not_implemented = missing_moments - not_implemented + if unexpected_not_implemented: + raise NotImplementedError( + f"Unexpected by this test, distributions {unexpected_not_implemented} do " + "not have a `get_moment` implementation. Either add a moment or filter " + "these distributions in this test." + ) + + def test_rv_size_is_none(): rv = Normal.dist(0, 1, size=None) assert rv_size_is_none(rv.owner.inputs[1]) @@ -85,20 +146,25 @@ def test_rv_size_is_none(): assert not rv_size_is_none(rv.owner.inputs[1]) -def assert_moment_is_expected(model, expected): +def assert_moment_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, return_transformed=False, default_strategy="moment", ) - result = fn(0)["x"] + moment = fn(0)["x"] expected = np.asarray(expected) try: random_draw = model["x"].eval() except NotImplementedError: - random_draw = result - assert result.shape == expected.shape == random_draw.shape - assert np.allclose(result, expected) + random_draw = moment + + assert moment.shape == expected.shape == random_draw.shape + assert np.allclose(moment, expected) + + if check_finite_logp: + logp_moment = logpt(model["x"], at.constant(moment), transformed=False).eval() + assert np.isfinite(logp_moment) @pytest.mark.parametrize( @@ -189,14 +255,13 @@ def test_halfstudentt_moment(nu, sigma, size, expected): assert_moment_is_expected(model, expected) -@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None") @pytest.mark.parametrize( "mu, sigma, lower, upper, size, expected", [ - (0.9, 1, -1, 1, None, 0), - (0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)), + (0.9, 1, -5, 5, None, 0), + (1, np.ones(5), -10, np.inf, None, np.full(5, -9)), (np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)), - (1, np.ones(5), -10, np.inf, None, np.full((2, 5), -9)), + (1, 1, [-np.inf, -np.inf, -np.inf], 10, None, np.full(3, 9)), ], ) def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected): @@ -371,11 +436,11 @@ def test_lognormal_moment(mu, sigma, size, expected): [ (1, None, 1), (1, 5, np.ones(5)), - (np.arange(5), None, np.arange(5)), + (np.arange(1, 5), None, np.arange(1, 5)), ( - np.arange(5), - (2, 5), - np.full((2, 5), np.arange(5)), + np.arange(1, 5), + (2, 4), + np.full((2, 4), np.arange(1, 5)), ), ], ) @@ -617,11 +682,11 @@ def test_logistic_moment(mu, s, size, expected): @pytest.mark.parametrize( "mu, nu, sigma, size, expected", [ - (1, 1, None, None, 2), + (1, 1, 1, None, 2), (1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)), - (1, 1, None, 5, np.full(5, 2)), - (1, np.arange(1, 6), None, None, np.arange(2, 7)), - (1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))), + (1, 1, 3, 5, np.full(5, 2)), + (1, np.arange(1, 6), 5, None, np.arange(2, 7)), + (1, np.arange(1, 6), 1, (2, 5), np.full((2, 5), np.arange(2, 7))), ], ) def test_exgaussian_moment(mu, nu, sigma, size, expected): @@ -861,8 +926,10 @@ def test_interpolated_moment(x_points, pdf_points, size, expected): ) def test_mv_normal_moment(mu, cov, size, expected): with Model() as model: - MvNormal("x", mu=mu, cov=cov, size=size) - assert_moment_is_expected(model, expected) + x = MvNormal("x", mu=mu, cov=cov, size=size) + + # MvNormal logp is only impemented for up to 2D variables + assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) @pytest.mark.parametrize( @@ -898,8 +965,10 @@ def test_moyal_moment(mu, sigma, size, expected): ) def test_mvstudentt_moment(nu, mu, cov, size, expected): with Model() as model: - MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size) - assert_moment_is_expected(model, expected) + x = MvStudentT("x", nu=nu, mu=mu, cov=cov, size=size) + + # MvStudentT logp is only impemented for up to 2D variables + assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) def check_matrixnormal_moment(mu, rowchol, colchol, size, expected): @@ -1035,7 +1104,7 @@ def test_density_dist_default_moment_univariate(get_moment, size, expected): get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype) with Model() as model: DensityDist("x", get_moment=get_moment, size=size) - assert_moment_is_expected(model, expected) + assert_moment_is_expected(model, expected, check_finite_logp=False) @pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) diff --git a/pymc/tests/test_initial_point.py b/pymc/tests/test_initial_point.py index 899209399e..7700c33451 100644 --- a/pymc/tests/test_initial_point.py +++ b/pymc/tests/test_initial_point.py @@ -17,6 +17,8 @@ import numpy as np import pytest +from aesara.tensor.random.op import RandomVariable + import pymc as pm from pymc.distributions.distribution import get_moment @@ -255,6 +257,30 @@ def test_moment_from_dims(self, rv_cls): assert tuple(get_moment(rv).shape.eval()) == (4, 3) pass + def test_moment_not_implemented_fallback(self): + class MyNormalRV(RandomVariable): + name = "my_normal" + ndim_supp = 0 + ndims_params = [0, 0] + dtype = "floatX" + + @classmethod + def rng_fn(cls, rng, mu, sigma, size): + return np.pi + + class MyNormalDistribution(pm.Normal): + rv_op = MyNormalRV() + + with pm.Model() as m: + x = MyNormalDistribution("x", 0, 1, initval="moment") + + with pytest.warns( + UserWarning, match="Moment not defined for variable x of type MyNormalRV" + ): + res = m.recompute_initial_point() + + assert np.isclose(res["x"], np.pi) + def test_pickling_issue_5090(): with pm.Model() as model: diff --git a/requirements-dev.txt b/requirements-dev.txt index 19d903eb68..6e96b42357 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # See that file for comments about the need/usage of each dependency. aeppl==0.0.18 -aesara>=2.2.6 +aesara==2.3.2 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle diff --git a/requirements.txt b/requirements.txt index 5dbc6e1f57..6c08953863 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aeppl==0.0.18 -aesara>=2.2.6 +aesara==2.3.2 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle