Skip to content

Commit 23753dd

Browse files
authored
Merge branch 'pymc-devs:main' into MvStudentT
2 parents 96cb9a4 + 45b3339 commit 23753dd

File tree

9 files changed

+251
-43
lines changed

9 files changed

+251
-43
lines changed

pymc/backends/arviz.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,26 @@
4242
Var = Any # pylint: disable=invalid-name
4343

4444

45+
def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
    """Collect the observed data attached to a model's observed variables.

    Parameters
    ----------
    model
        Object exposing ``observed_RVs``, or ``None``.

    Returns
    -------
    dict or None
        ``None`` when ``model`` is ``None``. Otherwise a dictionary mapping the
        name of each observed variable to the result of ``extract_obs_data``
        applied to its ``tag.observations``. Variables without
        ``tag.observations``, or whose extraction raises ``TypeError``, are
        omitted and a warning is emitted for each of them.
    """
    if model is None:
        return None

    observations = {}
    for obs in model.observed_RVs:
        aux_obs = getattr(obs.tag, "observations", None)
        if aux_obs is None:
            warnings.warn(f"No data for observation {obs}")
            continue
        try:
            obs_data = extract_obs_data(aux_obs)
        except TypeError:
            # Extraction failed for this variable: skip it, keep the others.
            warnings.warn(f"Could not extract data from symbolic observation {obs}")
        else:
            observations[obs.name] = obs_data

    return observations
63+
64+
4565
class _DefaultTrace:
4666
"""
4767
Utility for collecting samples into a dictionary.
@@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
196216
self.dims = {**model_dims, **self.dims}
197217

198218
self.density_dist_obs = density_dist_obs
199-
self.observations = self.find_observations()
200-
201-
def find_observations(self) -> Optional[Dict[str, Var]]:
202-
"""If there are observations available, return them as a dictionary."""
203-
if self.model is None:
204-
return None
205-
observations = {}
206-
for obs in self.model.observed_RVs:
207-
aux_obs = getattr(obs.tag, "observations", None)
208-
if aux_obs is not None:
209-
try:
210-
obs_data = extract_obs_data(aux_obs)
211-
observations[obs.name] = obs_data
212-
except TypeError:
213-
warnings.warn(f"Could not extract data from symbolic observation {obs}")
214-
else:
215-
warnings.warn(f"No data for observation {obs}")
216-
217-
return observations
219+
self.observations = find_observations(self.model)
218220

219221
def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
220222
"""Split MultiTrace object into posterior and warmup.

pymc/bart/bart.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
2020
from pandas import DataFrame, Series
2121

22-
from pymc.distributions.distribution import NoDistribution
22+
from pymc.distributions.distribution import NoDistribution, _get_moment
2323

2424
__all__ = ["BART"]
2525

@@ -110,6 +110,10 @@ def __new__(
110110

111111
NoDistribution.register(BARTRV)
112112

113+
@_get_moment.register(BARTRV)
def get_moment(rv, size, *rv_inputs):
    # Dispatch entry for BARTRV: forwards every argument unchanged to the
    # class-level ``get_moment`` of the enclosing ``cls``.
    # NOTE(review): the equivalent registration for SimulatorRV takes a
    # leading ``op`` argument, which this signature does not. Confirm which
    # positional arguments the ``_get_moment`` dispatcher actually passes
    # before changing either this function or ``cls.get_moment``.
    return cls.get_moment(rv, size, *rv_inputs)
116+
113117
cls.rv_op = bart_op
114118
params = [X, Y, m, alpha, k]
115119
return super().__new__(cls, name, *params, **kwargs)
@@ -132,6 +136,11 @@ def logp(x, *inputs):
132136
"""
133137
return at.zeros_like(x)
134138

139+
@classmethod
def get_moment(cls, rv, size, *rv_inputs):
    """Return the moment for BART, built from the mean of ``rv.Y``.

    The result is ``at.fill(size, rv.Y.mean())``; ``rv_inputs`` is accepted
    but not used.

    NOTE(review): this reads a ``Y`` attribute from the object received as
    ``rv`` -- presumably the response data stored on the BART op. Confirm
    against the dispatcher registration that the object passed in this
    position really exposes ``Y``.
    """
    mean = at.fill(size, rv.Y.mean())
    return mean
143+
135144

136145
def preprocess_XY(X, Y):
137146
if isinstance(Y, (Series, DataFrame)):

pymc/distributions/continuous.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,23 +2334,19 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar
23342334
alpha = at.as_tensor_variable(floatX(alpha))
23352335
beta = at.as_tensor_variable(floatX(beta))
23362336

2337-
# m = beta / (alpha - 1.0)
2338-
# try:
2339-
# mean = (alpha > 1) * m or np.inf
2340-
# except ValueError: # alpha is an array
2341-
# m[alpha <= 1] = np.inf
2342-
# mean = m
2343-
2344-
# mode = beta / (alpha + 1.0)
2345-
# variance = at.switch(
2346-
# at.gt(alpha, 2), (beta ** 2) / ((alpha - 2) * (alpha - 1.0) ** 2), np.inf
2347-
# )
2348-
23492337
assert_negative_support(alpha, "alpha", "InverseGamma")
23502338
assert_negative_support(beta, "beta", "InverseGamma")
23512339

23522340
return super().dist([alpha, beta], **kwargs)
23532341

2342+
def get_moment(rv, size, alpha, beta):
    """Return the moment for the InverseGamma distribution.

    Uses the mean ``beta / (alpha - 1)`` wherever ``alpha > 1`` and falls
    back to the mode ``beta / (alpha + 1)`` elsewhere. When
    ``rv_size_is_none(size)`` is false, the result is expanded with
    ``at.full(size, moment)``.
    """
    mean = beta / (alpha - 1.0)
    mode = beta / (alpha + 1.0)
    # Element-wise choice between the mean and the mode.
    moment = at.switch(alpha > 1, mean, mode)
    if not rv_size_is_none(size):
        moment = at.full(size, moment)
    return moment
2349+
23542350
@classmethod
23552351
def _get_alpha_beta(cls, alpha, beta, mu, sigma):
23562352
if alpha is not None:

pymc/distributions/discrete.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,7 +1462,7 @@ class ZeroInflatedBinomial(Discrete):
14621462
14631463
======== ==========================
14641464
Support :math:`x \in \mathbb{N}_0`
1465-
Mean :math:`(1 - \psi) n p`
1465+
Mean :math:`\psi n p`
14661466
Variance :math:`\psi n p [1 - p(1 - (1 - \psi) n)].`
14671467
======== ==========================
14681468
@@ -1487,7 +1487,7 @@ def dist(cls, psi, n, p, *args, **kwargs):
14871487
return super().dist([psi, n, p], *args, **kwargs)
14881488

14891489
def get_moment(rv, size, psi, n, p):
    """Return the moment for the ZeroInflatedBinomial distribution.

    The moment is the rounded value of ``psi * n * p``. When
    ``rv_size_is_none(size)`` is false, the result is expanded with
    ``at.full(size, mean)``.
    """
    mean = at.round(psi * n * p)
    if not rv_size_is_none(size):
        mean = at.full(size, mean)
    return mean
@@ -1650,6 +1650,12 @@ def dist(cls, psi, mu, alpha, *args, **kwargs):
16501650
p = at.as_tensor_variable(floatX(p))
16511651
return super().dist([psi, n, p], *args, **kwargs)
16521652

1653+
def get_moment(rv, size, psi, n, p):
    """Return the moment for the ZeroInflatedNegativeBinomial distribution.

    The moment is ``floor(psi * n * (1 - p) / p)``, computed from the
    ``n``/``p`` inputs this function receives. When
    ``rv_size_is_none(size)`` is false, the result is expanded with
    ``at.full(size, mean)``.
    """
    mean = at.floor(psi * n * (1 - p) / p)
    if not rv_size_is_none(size):
        mean = at.full(size, mean)
    return mean
1658+
16531659
def logp(value, psi, n, p):
16541660
r"""
16551661
Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value.

pymc/distributions/simulator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from scipy.spatial import cKDTree
2626

2727
from pymc.aesaraf import floatX
28-
from pymc.distributions.distribution import NoDistribution
28+
from pymc.distributions.distribution import NoDistribution, _get_moment
2929

3030
__all__ = ["Simulator"]
3131

@@ -223,13 +223,23 @@ def logp(op, value_var_list, *dist_params, **kwargs):
223223
value_var = value_var_list[0]
224224
return cls.logp(value_var, op, dist_params)
225225

226+
@_get_moment.register(SimulatorRV)
def get_moment(op, rv, size, *rv_inputs):
    # Dispatch entry for SimulatorRV: drops the leading ``op`` argument and
    # forwards the rest to the class-level ``get_moment`` of the enclosing
    # ``cls``.
    return cls.get_moment(rv, size, *rv_inputs)
229+
226230
cls.rv_op = sim_op
227231
return super().__new__(cls, name, *params, **kwargs)
228232

229233
@classmethod
230234
def dist(cls, *params, **kwargs):
231235
return super().dist(params, **kwargs)
232236

237+
@classmethod
def get_moment(cls, rv, size, *sim_inputs):
    """Return the moment for a Simulator variable as a mean of simulated draws.

    Calls ``rv.owner.op`` on ``sim_inputs`` with a size made of a leading
    ``10`` followed by ``rv.shape``, then averages over that leading axis.
    The ``size`` argument is accepted but not used.
    """
    # Take the mean of 10 draws
    multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
    return at.mean(multiple_sim, axis=0)
242+
233243
@classmethod
234244
def logp(cls, value, sim_op, sim_inputs):
235245
# Use a new rng to avoid non-randomness in parallel sampling

pymc/sampling_jax.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from aesara.link.jax.dispatch import jax_funcify
2727

2828
from pymc import Model, modelcontext
29-
from pymc.aesaraf import compile_rv_inplace, inputvars
29+
from pymc.aesaraf import compile_rv_inplace
30+
from pymc.backends.arviz import find_observations
31+
from pymc.distributions import logpt
3032
from pymc.util import get_default_varnames
3133

3234
warnings.warn("This module is experimental.")
@@ -95,6 +97,40 @@ def logp_fn_wrap(x):
9597
return logp_fn_wrap
9698

9799

100+
# Adapted from the ArviZ NumPyro extractor
101+
def _sample_stats_to_xarray(posterior):
102+
"""Extract sample_stats from NumPyro posterior."""
103+
rename_key = {
104+
"potential_energy": "lp",
105+
"adapt_state.step_size": "step_size",
106+
"num_steps": "n_steps",
107+
"accept_prob": "acceptance_rate",
108+
}
109+
data = {}
110+
for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
111+
if isinstance(value, (dict, tuple)):
112+
continue
113+
name = rename_key.get(stat, stat)
114+
value = value.copy()
115+
data[name] = value
116+
if stat == "num_steps":
117+
data["tree_depth"] = np.log2(value).astype(int) + 1
118+
return data
119+
120+
121+
def _get_log_likelihood(model, samples):
    """Compute log-likelihood for all observations.

    For every variable in ``model.observed_RVs`` a graph of ``logpt(v)`` is
    built with ``model.value_vars`` as inputs, optimized, converted with
    ``jax_funcify`` and evaluated on ``samples``.

    Parameters
    ----------
    model
        Object exposing ``observed_RVs`` and ``value_vars``.
    samples
        Sequence unpacked positionally into the compiled function, so its
        entries must line up with ``model.value_vars``.

    Returns
    -------
    dict
        Maps each observed variable's name to its evaluated log-likelihood.
    """
    data = {}
    for v in model.observed_RVs:
        logp_v = replace_shared_variables([logpt(v)])
        fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
        optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
        jax_fn = jax_funcify(fgraph)
        # Two nested vmaps map the function over the two leading axes of each
        # sample array -- presumably (chain, draw); confirm against the caller.
        result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]
        data[v.name] = result
    return data
132+
133+
98134
def sample_numpyro_nuts(
99135
draws=1000,
100136
tune=1000,
@@ -115,6 +151,20 @@ def sample_numpyro_nuts(
115151

116152
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
117153

154+
coords = {
155+
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
156+
for cname, cvals in model.coords.items()
157+
if cvals is not None
158+
}
159+
160+
if hasattr(model, "RV_dims"):
161+
dims = {
162+
var_name: [dim for dim in dims if dim is not None]
163+
for var_name, dims in model.RV_dims.items()
164+
}
165+
else:
166+
dims = {}
167+
118168
tic1 = pd.Timestamp.now()
119169
print("Compiling...", file=sys.stdout)
120170

@@ -151,9 +201,23 @@ def sample_numpyro_nuts(
151201
map_seed = jax.random.split(seed, chains)
152202

153203
if chains == 1:
154-
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
204+
init_params = init_state
205+
map_seed = seed
155206
else:
156-
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
207+
init_params = init_state_batched
208+
209+
pmap_numpyro.run(
210+
map_seed,
211+
init_params=init_params,
212+
extra_fields=(
213+
"num_steps",
214+
"potential_energy",
215+
"energy",
216+
"adapt_state.step_size",
217+
"accept_prob",
218+
"diverging",
219+
),
220+
)
157221

158222
raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
159223

@@ -164,6 +228,7 @@ def sample_numpyro_nuts(
164228
mcmc_samples = {}
165229
for v in vars_to_sample:
166230
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
231+
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
167232
jax_fn = jax_funcify(fgraph)
168233
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
169234
mcmc_samples[v.name] = result
@@ -172,6 +237,13 @@ def sample_numpyro_nuts(
172237
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
173238

174239
posterior = mcmc_samples
175-
az_trace = az.from_dict(posterior=posterior)
240+
az_trace = az.from_dict(
241+
posterior=posterior,
242+
log_likelihood=_get_log_likelihood(model, raw_mcmc_samples),
243+
observed_data=find_observations(model),
244+
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
245+
coords=coords,
246+
dims=dims,
247+
)
176248

177249
return az_trace

pymc/tests/test_bart.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import pymc as pm
88

9+
from pymc.tests.test_distributions_moments import assert_moment_is_expected
10+
911

1012
def test_split_node():
1113
split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
@@ -97,3 +99,17 @@ def test_predict(self):
9799
)
98100
def test_pdp(self, kwargs):
99101
pm.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs)
102+
103+
104+
@pytest.mark.parametrize(
    "size, expected",
    [
        (None, np.zeros(50)),
    ],
)
def test_bart_moment(size, expected):
    """Check the BART moment for an all-zero response of length 50."""
    X = np.zeros((50, 2))
    Y = np.zeros(50)
    with pm.Model() as model:
        pm.BART("x", X=X, Y=Y, size=size)
    assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)