Skip to content

Commit f12b1fe

Browse files
Only pack variables for which prior samples are available (#5338)
Closes #5337
1 parent 18e5e28 commit f12b1fe

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc/backends/arviz.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def priors_to_xarray(self):
412412
if self.prior is None:
413413
return {"prior": None, "prior_predictive": None}
414414
if self.observations is not None:
415-
prior_predictive_vars = list(self.observations.keys())
415+
prior_predictive_vars = list(set(self.observations).intersection(self.prior))
416416
prior_vars = [key for key in self.prior.keys() if key not in prior_predictive_vars]
417417
else:
418418
prior_vars = list(self.prior.keys())

pymc/tests/test_idata_conversion.py

+11
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,17 @@ def test_priors_separation(self, use_context):
572572
fails = check_multiple_attrs(test_dict, inference_data)
573573
assert not fails
574574

575+
def test_conversion_from_variables_subset(self):
576+
"""This is a regression test for issue #5337."""
577+
with pm.Model() as model:
578+
x = pm.Normal("x")
579+
pm.Normal("y", x, observed=5)
580+
idata = pm.sample(
581+
tune=10, draws=20, chains=1, step=pm.Metropolis(), compute_convergence_checks=False
582+
)
583+
pm.sample_posterior_predictive(idata, var_names=["x"])
584+
pm.sample_prior_predictive(var_names=["x"])
585+
575586
def test_multivariate_observations(self):
576587
coords = {"direction": ["x", "y", "z"], "experiment": np.arange(20)}
577588
data = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)

0 commit comments

Comments
 (0)