Skip to content

Commit 45b3339

Browse files
authored
add moment for BART distribution (#5211)
* add moment * simplify
1 parent 64d8396 commit 45b3339

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pymc/bart/bart.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
2020
from pandas import DataFrame, Series
2121

22-
from pymc.distributions.distribution import NoDistribution
22+
from pymc.distributions.distribution import NoDistribution, _get_moment
2323

2424
__all__ = ["BART"]
2525

@@ -110,6 +110,10 @@ def __new__(
110110

111111
NoDistribution.register(BARTRV)
112112

113+
@_get_moment.register(BARTRV)
114+
def get_moment(rv, size, *rv_inputs):
115+
return cls.get_moment(rv, size, *rv_inputs)
116+
113117
cls.rv_op = bart_op
114118
params = [X, Y, m, alpha, k]
115119
return super().__new__(cls, name, *params, **kwargs)
@@ -132,6 +136,11 @@ def logp(x, *inputs):
132136
"""
133137
return at.zeros_like(x)
134138

139+
@classmethod
140+
def get_moment(cls, rv, size, *rv_inputs):
141+
mean = at.fill(size, rv.Y.mean())
142+
return mean
143+
135144

136145
def preprocess_XY(X, Y):
137146
if isinstance(Y, (Series, DataFrame)):

pymc/tests/test_bart.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import pymc as pm
88

9+
from pymc.tests.test_distributions_moments import assert_moment_is_expected
10+
911

1012
def test_split_node():
1113
split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
@@ -97,3 +99,17 @@ def test_predict(self):
9799
)
98100
def test_pdp(self, kwargs):
99101
pm.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs)
102+
103+
104+
@pytest.mark.parametrize(
105+
"size, expected",
106+
[
107+
(None, np.zeros(50)),
108+
],
109+
)
110+
def test_bart_moment(size, expected):
111+
X = np.zeros((50, 2))
112+
Y = np.zeros(50)
113+
with pm.Model() as model:
114+
pm.BART("x", X=X, Y=Y, size=size)
115+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)