From 2b9e465b2383c2b5ba79a9279a3da0d269ce3d01 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 21 Sep 2021 21:54:25 +0200 Subject: [PATCH 1/5] Improve error message for missing get_moment implementations --- pymc/distributions/distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 2ffa092cd6..5f7cd7bc12 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -351,7 +351,9 @@ def dist( @singledispatch def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable: - return None + raise NotImplementedError( + f"Random variable {rv} of type {op} has no get_moment implementation." + ) def get_moment(rv: TensorVariable) -> TensorVariable: From bb09a5a974f21e63299ba3ca5b8dba8964fe392d Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 21 Sep 2021 20:03:28 +0200 Subject: [PATCH 2/5] Add get_moment implementations for Normal, Uniform and Binomial Co-authored-by: Adrian Seyboldt --- pymc/distributions/continuous.py | 12 ++++++++++-- pymc/distributions/discrete.py | 4 ++++ pymc/tests/test_initvals.py | 5 +++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index cbb2371368..d5d5dd39ed 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -336,6 +336,11 @@ def logcdf(value, lower, upper): ), ) + def get_moment(value, size, lower, upper): + lower = at.full(size, lower, dtype=aesara.config.floatX) + upper = at.full(size, upper, dtype=aesara.config.floatX) + return (lower + upper) / 2 + class FlatRV(RandomVariable): name = "flat" @@ -366,7 +371,7 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(0.0)) return res - def get_moment(rv, size, *rv_inputs) -> np.ndarray: + def get_moment(rv, size, *rv_inputs): return at.zeros(size, dtype=aesara.config.floatX) def logp(value): @@ -431,7 +436,7 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(1.0)) return res - def get_moment(value_var, size, *rv_inputs) -> np.ndarray: + def get_moment(value_var, size, *rv_inputs): return at.ones(size, dtype=aesara.config.floatX) def logp(value): @@ -588,6 +593,9 @@ def logcdf(value, mu, sigma): 0 < sigma, ) + def get_moment(value_var, size, mu, sigma): + return at.full(size, mu, dtype=aesara.config.floatX) + class TruncatedNormalRV(RandomVariable): name = "truncated_normal" diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d3269f6d28..f406dea477 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -394,6 +394,10 @@ def logcdf(value, p): p <= 1, ) + def get_moment(value, size, p): + p = at.full(size, p) + return at.switch(p < 0.5, np.int64(0), np.int64(1)) + def _distr_parameters_for_repr(self): return ["p"] diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index 9ebfb98e0d..b6cadf6d61 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -95,6 +95,11 @@ def test_automatically_assigned_test_values(self): class TestMoment: def test_basic(self): + # Standard distributions + rv = pm.Normal.dist(mu=2.3) + np.testing.assert_allclose(get_moment(rv).eval(), 2.3) + + # Special distributions rv = pm.Flat.dist() assert get_moment(rv).eval() == np.zeros(()) rv = pm.HalfFlat.dist() From beb5bb9ed521316e9837e33ec8d2036ead548b7a Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Wed, 15 Sep 2021 13:33:51 +0200 Subject: [PATCH 3/5] Add tests for (Half)Flat moments with symbolic dimensionality Closes #4993 --- pymc/tests/test_initvals.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index b6cadf6d61..6b4ef717a4 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aesara.tensor as at import numpy as np import pytest @@ -108,3 +109,33 @@ def test_basic(self): assert np.all(get_moment(rv).eval() == np.zeros((2, 4))) rv = pm.HalfFlat.dist(size=(2, 4)) assert np.all(get_moment(rv).eval() == np.ones((2, 4))) + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_numeric_moment_shape(self, rv_cls): + rv = rv_cls.dist(shape=(2,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (2,) + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_symbolic_moment_shape(self, rv_cls): + s = at.scalar() + rv = rv_cls.dist(shape=(s,)) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) + pass + + @pytest.mark.xfail(reason="Test values are still used for initvals.") + @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) + def test_moment_from_dims(self, rv_cls): + with pm.Model( + coords={ + "year": [2019, 2020, 2021, 2022], + "city": ["Bonn", "Paris", "Lisbon"], + } + ): + rv = rv_cls("rv", dims=("year", "city")) + assert not hasattr(rv.tag, "test_value") + assert tuple(get_moment(rv).shape.eval()) == (4, 3) + pass From 116ae6bff655876d8ecc772e5f7a9b9c910717bd Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 25 Sep 2021 14:30:52 +0100 Subject: [PATCH 4/5] Apply suggestions from code review --- pymc/distributions/continuous.py | 2 +- pymc/distributions/discrete.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d5d5dd39ed..56ab7d00e1 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -337,7 +337,7 @@ def logcdf(value, lower, upper): ) def get_moment(value, size, lower, upper): - lower = at.full(size, lower, dtype=aesara.config.floatX) + lower = at.full(size, lower, dtype=value.owner.op.inputs[2]) upper = at.full(size, upper, dtype=aesara.config.floatX) return (lower + upper) / 2 diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index f406dea477..2f2714edd6 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -396,7 +396,7 @@ def logcdf(value, p): def get_moment(value, size, p): p = at.full(size, p) - return at.switch(p < 0.5, np.int64(0), np.int64(1)) + return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value)) def _distr_parameters_for_repr(self): return ["p"] From b9001f7d1191bfcf6a8dac1959762f9a0b7bb541 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sat, 25 Sep 2021 22:28:27 +0100 Subject: [PATCH 5/5] Update pymc/distributions/continuous.py --- pymc/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 56ab7d00e1..d5d5dd39ed 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -337,7 +337,7 @@ def logcdf(value, lower, upper): ) def get_moment(value, size, lower, upper): - lower = at.full(size, lower, dtype=value.owner.op.inputs[2]) + lower = at.full(size, lower, dtype=aesara.config.floatX) upper = at.full(size, upper, dtype=aesara.config.floatX) return (lower + upper) / 2