Skip to content

Commit 687f044

Browse files
committed
Enable prior_predictive to return transformed values
1 parent 7e35cdd commit 687f044

File tree

3 files changed

+60
-3
lines changed

3 files changed

+60
-3
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
88
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
99
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
10-
- ...
10+
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc3/pull/4769)).
11+
...
1112

1213
### New Features
1314
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.

pymc3/sampling.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,7 +1943,8 @@ def sample_prior_predictive(
19431943
model : Model (optional if in ``with`` context)
19441944
var_names : Iterable[str]
19451945
A list of names of variables for which to compute the posterior predictive
1946-
samples. Defaults to both observed and unobserved RVs.
1946+
samples. Defaults to both observed and unobserved RVs. Transformed values
1947+
are not included unless explicitly defined in var_names.
19471948
random_seed : int
19481949
Seed for the random number generator.
19491950
mode:
@@ -1983,8 +1984,26 @@ def sample_prior_predictive(
19831984
)
19841985

19851986
names = get_default_varnames(vars_, include_transformed=False)
1986-
19871987
vars_to_sample = [model[name] for name in names]
1988+
1989+
# Any variables from var_names that are missing must be transformed variables.
1990+
# Misspelled variables would have raised a KeyError above.
1991+
missing_names = vars_.difference(names)
1992+
for name in missing_names:
1993+
transformed_value_var = model[name]
1994+
rv_var = model.values_to_rvs[transformed_value_var]
1995+
transform = transformed_value_var.tag.transform
1996+
transformed_rv_var = transform.forward(rv_var, rv_var)
1997+
1998+
names.append(name)
1999+
vars_to_sample.append(transformed_rv_var)
2000+
2001+
# If the user asked for the transformed variable in var_names, but not the
2002+
# original RV, we add it manually here
2003+
if rv_var.name not in names:
2004+
names.append(rv_var.name)
2005+
vars_to_sample.append(rv_var)
2006+
19882007
inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, SharedVariable)]
19892008

19902009
sampler_fn = compile_rv_inplace(

pymc3/tests/test_sampling.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,43 @@ def test_potentials_warning(self):
10761076
with pytest.warns(UserWarning, match=warning_msg):
10771077
pm.sample_prior_predictive(samples=5)
10781078

1079+
def test_transformed_vars(self):
1080+
# Test that prior predictive returns transformation of RVs when these are
1081+
# passed explicitly in `var_names`
1082+
1083+
def ub_interval_forward(x, ub):
1084+
# Interval transform assuming lower bound is zero
1085+
return np.log(x - 0) - np.log(ub - x)
1086+
1087+
with pm.Model(rng_seeder=123) as model:
1088+
ub = pm.HalfNormal("ub", 10)
1089+
x = pm.Uniform("x", 0, ub)
1090+
1091+
prior = pm.sample_prior_predictive(
1092+
var_names=["ub", "ub_log__", "x", "x_interval__"],
1093+
samples=10,
1094+
)
1095+
1096+
# Check values are correct
1097+
assert np.allclose(prior["ub_log__"], np.log(prior["ub"]))
1098+
assert np.allclose(
1099+
prior["x_interval__"],
1100+
ub_interval_forward(prior["x"], prior["ub"]),
1101+
)
1102+
1103+
# Check that it works when the original RVs are not mentioned in var_names
1104+
with pm.Model(rng_seeder=123) as model_transformed_only:
1105+
ub = pm.HalfNormal("ub", 10)
1106+
x = pm.Uniform("x", 0, ub)
1107+
1108+
prior_transformed_only = pm.sample_prior_predictive(
1109+
var_names=["ub_log__", "x_interval__"],
1110+
samples=10,
1111+
)
1112+
assert "ub" not in prior_transformed_only and "x" not in prior_transformed_only
1113+
assert np.allclose(prior["ub_log__"], prior_transformed_only["ub_log__"])
1114+
assert np.allclose(prior["x_interval__"], prior_transformed_only["x_interval__"])
1115+
10791116

10801117
class TestSamplePosteriorPredictive:
10811118
def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):

0 commit comments

Comments
 (0)