diff --git a/pymc/bart/bart.py b/pymc/bart/bart.py index 783378d300..7c4d09c5f2 100644 --- a/pymc/bart/bart.py +++ b/pymc/bart/bart.py @@ -19,7 +19,7 @@ from aesara.tensor.random.op import RandomVariable, default_shape_from_params from pandas import DataFrame, Series -from pymc.distributions.distribution import NoDistribution +from pymc.distributions.distribution import NoDistribution, _get_moment __all__ = ["BART"] @@ -110,6 +110,10 @@ def __new__( NoDistribution.register(BARTRV) + @_get_moment.register(BARTRV) + def get_moment(rv, size, *rv_inputs): + return cls.get_moment(rv, size, *rv_inputs) + cls.rv_op = bart_op params = [X, Y, m, alpha, k] return super().__new__(cls, name, *params, **kwargs) @@ -132,6 +136,11 @@ def logp(x, *inputs): """ return at.zeros_like(x) + @classmethod + def get_moment(cls, rv, size, *rv_inputs): + mean = at.fill(size, rv.Y.mean()) + return mean + def preprocess_XY(X, Y): if isinstance(Y, (Series, DataFrame)): diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py index 901e4e4f91..3db5543908 100644 --- a/pymc/tests/test_bart.py +++ b/pymc/tests/test_bart.py @@ -6,6 +6,8 @@ import pymc as pm +from pymc.tests.test_distributions_moments import assert_moment_is_expected + def test_split_node(): split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0) @@ -97,3 +99,17 @@ def test_predict(self): ) def test_pdp(self, kwargs): pm.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs) + + +@pytest.mark.parametrize( + "size, expected", + [ + (None, np.zeros(50)), + ], +) +def test_bart_moment(size, expected): + X = np.zeros((50, 2)) + Y = np.zeros(50) + with pm.Model() as model: + pm.BART("x", X=X, Y=Y, size=size) + assert_moment_is_expected(model, expected)