From 9f59a7b669f5587bac22d410005c044570053c51 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Sat, 5 Dec 2020 10:10:48 +0100 Subject: [PATCH 1/5] informative warnings on bound method logp in DensityDist --- pymc3/distributions/distribution.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 066dbfd26d..47c2bf7d96 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing import numbers import contextvars import dill import inspect +import sys +import types from typing import TYPE_CHECKING +import warnings if TYPE_CHECKING: from typing import Optional, Callable @@ -505,6 +509,15 @@ def __init__( dtype = theano.config.floatX super().__init__(shape, dtype, testval, *args, **kwargs) self.logp = logp + if type(self.logp) == types.MethodType: + if sys.platform != "linux": + warnings.warn("You are passing a bound method as logp for DensityDist, this can lead to " + + "errors when sampling on platforms other than Linux. Consider using a " + + "plain function instead, or subclass Distribution.") + elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext: + warnings.warn("You are passing a bound method as logp for DensityDist, this can lead to " + + "errors when sampling when multiprocessing cannot rely on forking. Consider using a " + + "plain function instead, or subclass Distribution.") self.rand = random self.wrap_random_with_dist_shape = wrap_random_with_dist_shape self.check_shape_in_random = check_shape_in_random @@ -513,7 +526,13 @@ def __getstate__(self): # We use dill to serialize the logp function, as this is almost # always defined in the notebook and won't be pickled correctly. # Fix https://github.com/pymc-devs/pymc3/issues/3844 - logp = dill.dumps(self.logp) + try: + logp = dill.dumps(self.logp) + except RecursionError as err: + if type(self.logp) == types.MethodType: + raise ValueError("logp for DensityDist is a bound method, leading to RecursionError while serializing") from err + else: + raise err vals = self.__dict__.copy() vals["logp"] = logp return vals From 3e0c166e0782cc8caec1131df4e6c09e3da393cb Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Sat, 5 Dec 2020 10:12:17 +0100 Subject: [PATCH 2/5] black --- pymc3/distributions/distribution.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 47c2bf7d96..734722ae49 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -511,13 +511,17 @@ def __init__( self.logp = logp if type(self.logp) == types.MethodType: if sys.platform != "linux": - warnings.warn("You are passing a bound method as logp for DensityDist, this can lead to " + - "errors when sampling on platforms other than Linux. Consider using a " + - "plain function instead, or subclass Distribution.") + warnings.warn( + "You are passing a bound method as logp for DensityDist, this can lead to " + + "errors when sampling on platforms other than Linux. Consider using a " + + "plain function instead, or subclass Distribution." + ) elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext: - warnings.warn("You are passing a bound method as logp for DensityDist, this can lead to " + - "errors when sampling when multiprocessing cannot rely on forking. Consider using a " + - "plain function instead, or subclass Distribution.") + warnings.warn( + "You are passing a bound method as logp for DensityDist, this can lead to " + + "errors when sampling when multiprocessing cannot rely on forking. Consider using a " + + "plain function instead, or subclass Distribution." + ) self.rand = random self.wrap_random_with_dist_shape = wrap_random_with_dist_shape self.check_shape_in_random = check_shape_in_random @@ -530,7 +534,9 @@ def __getstate__(self): logp = dill.dumps(self.logp) except RecursionError as err: if type(self.logp) == types.MethodType: - raise ValueError("logp for DensityDist is a bound method, leading to RecursionError while serializing") from err + raise ValueError( + "logp for DensityDist is a bound method, leading to RecursionError while serializing" + ) from err else: raise err vals = self.__dict__.copy() From fa9ffd46cb4a1c52f50d99c338868809a0c8e6f0 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Sat, 5 Dec 2020 10:18:08 +0100 Subject: [PATCH 3/5] run test on single core only to avoid dill error on windows/macos --- pymc3/tests/test_distributions_random.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 84bd8bc117..2a00d577ba 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -1171,7 +1171,7 @@ def test_density_dist_with_random_sampleable(self, shape): shape=shape, random=normal_dist.random, ) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 size = 100 @@ -1194,7 +1194,7 @@ def test_density_dist_with_random_sampleable_failure(self, shape): random=normal_dist.random, wrap_random_with_dist_shape=False, ) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 with pytest.raises(RuntimeError): @@ -1217,7 +1217,7 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape): wrap_random_with_dist_shape=False, check_shape_in_random=False, ) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model) @@ -1240,7 +1240,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success(self): random=rvs, wrap_random_with_dist_shape=False, ) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 size = 100 @@ -1260,7 +1260,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success_fast(self): random=rvs, wrap_random_with_dist_shape=False, ) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 size = 100 From 2b1d1a308cb2b880db652f7ebb03830b5438572d Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Sat, 5 Dec 2020 10:54:14 +0100 Subject: [PATCH 4/5] adding tests for DensityDist serialize recursion handling --- pymc3/tests/test_parallel_sampling.py | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py index b5de8332cc..f8063663e8 100644 --- a/pymc3/tests/test_parallel_sampling.py +++ b/pymc3/tests/test_parallel_sampling.py @@ -159,3 +159,45 @@ def test_iterator(): with sampler: for draw in sampler: pass + + +def test_spawn_densitydist_function(): + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + + def func(x): + return -2 * (x ** 2).sum() + + obs = pm.DensityDist("density_dist", func, observed=np.random.randn(100)) + trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + + +@pytest.mark.xfail(raises=ValueError) +def test_spawn_densitydist_bound_method(): + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + normal_dist = pm.Normal.dist(mu, 1) + obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) + trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + + +# cannot test this properly: monkeypatching sys.platform messes up Theano +# def test_spawn_densitydist_syswarning(monkeypatch): +# monkeypatch.setattr(sys, "platform", "win32") +# with pm.Model() as model: +# mu = pm.Normal('mu', 0, 1) +# normal_dist = pm.Normal.dist(mu, 1) +# with pytest.warns(UserWarning) as w: +# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100)) +# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0] + + +def test_spawn_densitydist_mpctxwarning(monkeypatch): + ctx = multiprocessing.get_context("spawn") + monkeypatch.setattr(multiprocessing, "get_context", lambda: ctx) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + normal_dist = pm.Normal.dist(mu, 1) + with pytest.warns(UserWarning) as w: + obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) + assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0] From d70728282ee7db8d4f9f79822693cd15c826df53 Mon Sep 17 00:00:00 2001 From: Eelke Spaak Date: Sat, 5 Dec 2020 10:57:21 +0100 Subject: [PATCH 5/5] forgot a test --- pymc3/tests/test_distributions_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 2a00d577ba..a789674095 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -1273,7 +1273,7 @@ def test_density_dist_without_random_not_sampleable(self): mu = pm.Normal("mu", 0, 1) normal_dist = pm.Normal.dist(mu, 1) pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) - trace = pm.sample(100) + trace = pm.sample(100, cores=1) samples = 500 with pytest.raises(ValueError):