sample_prior_predictive(var_names=[...]) results in KeyError inside to_inferencedata() #5337

Closed
@michaelosthege

Description of your problem

Please provide a minimal, self-contained, and reproducible example.

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=5)
    idata = pm.sample(tune=10, draws=20, chains=1, step=pm.Metropolis())
    pm.sample_posterior_predictive(idata, var_names=["x"]) # 👈 this works fine
    pm.sample_prior_predictive(var_names=["x"])          # 👈 this doesn't (KeyError: 'y')

Please provide the full traceback.

Complete error traceback
KeyError                                  Traceback (most recent call last)
<ipython-input-12-46d00c82b13d> in <module>
      4     idata = pm.sample(tune=10, draws=20, chains=1, step=pm.Metropolis())
      5     pm.sample_posterior_predictive(idata)
----> 6     pm.sample_prior_predictive(var_names=["x"])

c:\users\osthege\repos\pymc-main\pymc\sampling.py in sample_prior_predictive(samples, model, var_names, random_seed, mode, return_inferencedata, idata_kwargs)
   2030     if idata_kwargs:
   2031         ikwargs.update(idata_kwargs)
-> 2032     return pm.to_inference_data(prior=prior, **ikwargs)
   2033 
   2034 

c:\users\osthege\repos\pymc-main\pymc\backends\arviz.py in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, model, save_warmup, density_dist_obs)
    587         return trace
    588 
--> 589     return InferenceDataConverter(
    590         trace=trace,
    591         prior=prior,

c:\users\osthege\repos\pymc-main\pymc\backends\arviz.py in to_inference_data(self)
    523             "posterior_predictive": self.posterior_predictive_to_xarray(),
    524             "predictions": self.predictions_to_xarray(),
--> 525             **self.priors_to_xarray(),
    526             "observed_data": self.observed_data_to_xarray(),
    527         }

c:\users\osthege\repos\pymc-main\pymc\backends\arviz.py in priors_to_xarray(self)
    442                 if var_names is None
    443                 else dict_to_dataset(
--> 444                     {k: np.expand_dims(self.prior[k], 0) for k in var_names},
    445                     library=pymc,
    446                     coords=self.coords,

c:\users\osthege\repos\pymc-main\pymc\backends\arviz.py in <dictcomp>(.0)
    442                 if var_names is None
    443                 else dict_to_dataset(
--> 444                     {k: np.expand_dims(self.prior[k], 0) for k in var_names},
    445                     library=pymc,
    446                     coords=self.coords,

KeyError: 'y'
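The last frames show the failing pattern. `priors_to_xarray` builds each group with `{k: np.expand_dims(self.prior[k], 0) for k in var_names}`, so any name in `var_names` that was not actually sampled raises `KeyError`. Here `prior` should only hold `x` (the one requested variable), yet the missing key is the observed `y`. That suggests the `prior_predictive` name list is taken from the model's observed variables instead of from the keys of the sampled dict. The sketch below reproduces that pattern in plain NumPy. The function names, and the assumption that the observed names come straight from the model, are hypothetical and not PyMC's actual code. The guarded variant shows one possible fix, not necessarily the one adopted upstream.

```python
import numpy as np

# Hypothetical stand-in for the dict returned by
# sample_prior_predictive(var_names=["x"]): only "x" was sampled.
prior = {"x": np.random.default_rng(0).normal(size=500)}

# Hypothetical list of observed variable names read from the model;
# "y" is observed even though it was not requested in var_names.
observed_names = ["y"]


def split_groups_naive(prior, observed_names):
    """Mimic the failing pattern: trust the model's observed names blindly."""
    # expand_dims(..., 0) adds a leading chain dimension of size 1.
    prior_predictive = {k: np.expand_dims(prior[k], 0) for k in observed_names}
    prior_vars = {
        k: np.expand_dims(v, 0) for k, v in prior.items() if k not in observed_names
    }
    return prior_vars, prior_predictive


def split_groups_guarded(prior, observed_names):
    """Only convert observed names that are actually present in the sampled dict."""
    present = [k for k in observed_names if k in prior]
    prior_predictive = {k: np.expand_dims(prior[k], 0) for k in present}
    prior_vars = {k: np.expand_dims(v, 0) for k, v in prior.items() if k not in present}
    return prior_vars, prior_predictive


try:
    split_groups_naive(prior, observed_names)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'y'

prior_vars, prior_pred = split_groups_guarded(prior, observed_names)
print(prior_vars["x"].shape, prior_pred)  # (1, 500) {}
```

With the guard, the `prior_predictive` group simply ends up empty when no observed variable was requested, rather than failing the whole conversion.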

Versions and main components

  • PyMC/PyMC3 Version: main
