diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py index 7a4ab43d30..6b9cdf7d58 100644 --- a/pymc3/distributions/posterior_predictive.py +++ b/pymc3/distributions/posterior_predictive.py @@ -22,6 +22,8 @@ import numpy as np import theano import theano.tensor as tt +import arviz as az + from xarray import Dataset from arviz import InferenceData @@ -163,7 +165,8 @@ def fast_sample_posterior_predictive( var_names: Optional[List[str]] = None, keep_size: bool = False, random_seed=None, -) -> Dict[str, np.ndarray]: + add_to_inference_data: Optional[bool]=None, +) -> Union[Dict[str, np.ndarray], InferenceData]: """Generate posterior predictive samples from a model given a trace. This is a vectorized alternative to the standard ``sample_posterior_predictive`` function. @@ -190,12 +193,17 @@ def fast_sample_posterior_predictive( data: ``(nchains, ndraws, ...)``. random_seed: int Seed for the random number generator. + add_to_inference_data : bool, Optional + If true or unsupplied, and the ``trace`` argument is an ``InferenceData``, return a new + ``InferenceData`` object with the posterior predictive samples in the right group. + Defaults to True, *if* the ``trace`` is an ``InferenceData``, else False. Returns ------- - samples: dict + samples: dict or InferenceData Dictionary with the variable names as keys, and values numpy arrays containing - posterior predictive samples. + posterior predictive samples. See discussion of ``add_to_inference_data`` argument + for explanation of ``InferenceData`` return. """ ### Implementation note: primarily this function canonicalizes the arguments: @@ -207,16 +215,24 @@ def fast_sample_posterior_predictive( ### greater than the number of samples in the trace parameter, we sample repeatedly. This ### makes the shape issues just a little easier to deal with. + if not isinstance(trace, InferenceData): + if add_to_inference_data: + raise IncorrectArgumentsError("add_to_inference_data is only valid if an InferenceData is supplied.") + if isinstance(trace, InferenceData): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace.posterior) + _trace0 = dataset_to_point_dict(trace.posterior) + if add_to_inference_data is None: + add_to_inference_data = True elif isinstance(trace, Dataset): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace) + _trace0 = dataset_to_point_dict(trace) elif isinstance(trace, MultiTrace): + _trace0 = trace nchains = trace.nchains ndraws = len(trace) else: + _trace0 = trace if keep_size: # arguably this should be just a warning. raise IncorrectArgumentsError( @@ -230,10 +246,10 @@ def fast_sample_posterior_predictive( if keep_size and samples is not None: raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments") - if isinstance(trace, list) and all(isinstance(x, dict) for x in trace): - _trace = _TraceDict(point_list=trace) - elif isinstance(trace, MultiTrace): - _trace = _TraceDict(multi_trace=trace) + if isinstance(_trace0, list) and all((isinstance(x, dict) for x in _trace0)): + _trace = _TraceDict(point_list=_trace0) + elif isinstance(_trace0, MultiTrace): + _trace = _TraceDict(multi_trace=_trace0) else: raise TypeError( "Unable to generate posterior predictive samples from argument of type %s" @@ -305,6 +321,9 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None: k: ary.reshape((nchains, ndraws, *ary.shape[1:])) for k, ary in ppc_trace.items() } # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv. + if add_to_inference_data: + assert isinstance(trace, InferenceData) + return az.concat(trace, az.from_dict(posterior_predictive=ppc_trace.data)) return ppc_trace.data diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index a415f648ea..f6a3f49ce5 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -400,6 +400,21 @@ def test_normal_scalar(self): assert ppc["a"].shape == (nchains, ndraws) ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True) assert ppc["a"].shape == (nchains, ndraws) + # test returning an InferenceData object + ppc_idata = pm.fast_sample_posterior_predictive(idata, add_to_inference_data=True) + assert isinstance(ppc_idata, az.InferenceData) + assert hasattr(ppc_idata, 'posterior_predictive') + assert set(ppc_idata.posterior_predictive.data_vars.keys()) == set(["a"]) + # if you pass an InferenceData, you get one back by default -- + # this may be controversial + ppc_idata = pm.fast_sample_posterior_predictive(idata) + assert isinstance(ppc_idata, az.InferenceData) + assert hasattr(ppc_idata, 'posterior_predictive') + assert set(ppc_idata.posterior_predictive.data_vars.keys()) == set(["a"]) + # but you don't have to get one if you don't want + ppc = pm.fast_sample_posterior_predictive(idata, keep_size=True, add_to_inference_data=False) + assert not isinstance(ppc, az.InferenceData) + assert ppc["a"].shape == (nchains, ndraws) # test default case ppc = pm.sample_posterior_predictive(trace, var_names=["a"]) @@ -500,6 +515,8 @@ def test_exceptions(self, caplog): ppc = pm.sample_posterior_predictive(bad_trace) with pytest.raises(TypeError): ppc = pm.fast_sample_posterior_predictive(bad_trace) + with pytest.raises(IncorrectArgumentsError): + ppc_idata = pm.fast_sample_posterior_predictive(trace, add_to_inference_data=True) def test_vector_observed(self): with pm.Model() as model: