From 862ff2e73f94bcff3e29e2e031747be35c96c171 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sun, 19 Jul 2020 19:35:49 -0500 Subject: [PATCH] add_to_inference_data for fast_sample_posterior_predictive --- pymc3/distributions/posterior_predictive.py | 37 ++++++++++++++++----- pymc3/tests/test_sampling.py | 17 ++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py index 5c7bab07ef..98155f7990 100644 --- a/pymc3/distributions/posterior_predictive.py +++ b/pymc3/distributions/posterior_predictive.py @@ -22,6 +22,8 @@ import numpy as np import theano import theano.tensor as tt +import arviz as az + from xarray import Dataset from arviz import InferenceData @@ -173,7 +175,8 @@ def fast_sample_posterior_predictive( var_names: Optional[List[str]] = None, keep_size: bool = False, random_seed=None, -) -> Dict[str, np.ndarray]: + add_to_inference_data: Optional[bool]=None, +) -> Union[Dict[str, np.ndarray], InferenceData]: """Generate posterior predictive samples from a model given a trace. This is a vectorized alternative to the standard ``sample_posterior_predictive`` function. @@ -200,12 +203,17 @@ def fast_sample_posterior_predictive( data: ``(nchains, ndraws, ...)``. random_seed: int Seed for the random number generator. + add_to_inference_data : bool, Optional + If true or unsupplied, and the ``trace`` argument is an ``InferenceData``, return a new + ``InferenceData`` object with the posterior predictive samples in the right group. + Defaults to True, *if* the ``trace`` is an ``InferenceData``, else False. Returns ------- - samples: dict + samples: dict or InferenceData Dictionary with the variable names as keys, and values numpy arrays containing - posterior predictive samples. + posterior predictive samples. See discussion of ``add_to_inference_data`` argument + for explanation of ``InferenceData`` return. """ ### Implementation note: primarily this function canonicalizes the arguments: @@ -217,16 +225,24 @@ def fast_sample_posterior_predictive( ### greater than the number of samples in the trace parameter, we sample repeatedly. This ### makes the shape issues just a little easier to deal with. + if not isinstance(trace, InferenceData): + if add_to_inference_data: + raise IncorrectArgumentsError("add_to_inference_data is only valid if an InferenceData is supplied.") + if isinstance(trace, InferenceData): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace.posterior) + _trace0 = dataset_to_point_dict(trace.posterior) + if add_to_inference_data is None: + add_to_inference_data = True elif isinstance(trace, Dataset): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace) + _trace0 = dataset_to_point_dict(trace) elif isinstance(trace, MultiTrace): + _trace0 = trace nchains = trace.nchains ndraws = len(trace) else: + _trace0 = trace if keep_size: # arguably this should be just a warning. raise IncorrectArgumentsError( @@ -242,10 +258,10 @@ def fast_sample_posterior_predictive( "Should not specify both keep_size and samples arguments" ) - if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)): - _trace = _TraceDict(point_list=trace) - elif isinstance(trace, MultiTrace): - _trace = _TraceDict(multi_trace=trace) + if isinstance(_trace0, list) and all((isinstance(x, dict) for x in _trace0)): + _trace = _TraceDict(point_list=_trace0) + elif isinstance(_trace0, MultiTrace): + _trace = _TraceDict(multi_trace=_trace0) else: raise TypeError( "Unable to generate posterior predictive samples from argument of type %s" @@ -322,6 +338,9 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None: for k, ary in ppc_trace.items() } # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv. + if add_to_inference_data: + assert isinstance(trace, InferenceData) + return az.concat(trace, az.from_dict(posterior_predictive=ppc_trace.data)) return ppc_trace.data diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 91ce994862..7e48ecd861 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -387,6 +387,21 @@ def test_normal_scalar(self): assert ppc["a"].shape == (nchains, ndraws) ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True) assert ppc["a"].shape == (nchains, ndraws) + # test returning an InferenceData object + ppc_idata = pm.fast_sample_posterior_predictive(idata, add_to_inference_data=True) + assert isinstance(ppc_idata, az.InferenceData) + assert hasattr(ppc_idata, 'posterior_predictive') + assert set(ppc_idata.posterior_predictive.data_vars.keys()) == set(["a"]) + # if you pass an InferenceData, you get one back by default -- + # this may be controversial + ppc_idata = pm.fast_sample_posterior_predictive(idata) + assert isinstance(ppc_idata, az.InferenceData) + assert hasattr(ppc_idata, 'posterior_predictive') + assert set(ppc_idata.posterior_predictive.data_vars.keys()) == set(["a"]) + # but you don't have to get one if you don't want + ppc = pm.fast_sample_posterior_predictive(idata, keep_size=True, add_to_inference_data=False) + assert not isinstance(ppc, az.InferenceData) + assert ppc["a"].shape == (nchains, ndraws) # test default case ppc = pm.sample_posterior_predictive(trace, var_names=["a"]) @@ -489,6 +504,8 @@ def test_exceptions(self, caplog): ppc = pm.sample_posterior_predictive(bad_trace) with pytest.raises(TypeError): ppc = pm.fast_sample_posterior_predictive(bad_trace) + with pytest.raises(IncorrectArgumentsError): + ppc_idata = pm.fast_sample_posterior_predictive(trace, add_to_inference_data=True) def test_vector_observed(self): with pm.Model() as model: