Skip to content

Commit 7c4d5ab

Browse files
Alihan ZihnaricardoV94
Alihan Zihna
authored andcommitted
Add wald and pareto moments
1 parent 31b4a37 commit 7c4d5ab

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,12 @@ def dist(
995995

996996
return super().dist([mu, lam, alpha], **kwargs)
997997

998+
def get_moment(rv, size, mu, lam, alpha):
999+
mu, _, _ = at.broadcast_arrays(mu, lam, alpha)
1000+
if not rv_size_is_none(size):
1001+
mu = at.full(size, mu)
1002+
return mu
1003+
9981004
@staticmethod
9991005
def get_mu_lam_phi(
10001006
mu: Optional[float], lam: Optional[float], phi: Optional[float]
@@ -1943,8 +1949,11 @@ def dist(
19431949

19441950
return super().dist([alpha, m], **kwargs)
19451951

1946-
def _distr_parameters_for_repr(self):
1947-
return ["alpha", "m"]
1952+
def get_moment(rv, size, alpha, m):
1953+
median = m * 2 ** (1 / alpha)
1954+
if not rv_size_is_none(size):
1955+
median = at.full(size, median)
1956+
return median
19481957

19491958
def logcdf(
19501959
value: Union[float, np.ndarray, TensorVariable],

pymc/tests/test_distributions_moments.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
LogNormal,
2727
NegativeBinomial,
2828
Normal,
29+
Pareto,
2930
Poisson,
3031
StudentT,
3132
TruncatedNormal,
3233
Uniform,
34+
Wald,
3335
Weibull,
3436
ZeroInflatedBinomial,
3537
ZeroInflatedPoisson,
@@ -360,6 +362,42 @@ def test_gamma_moment(alpha, beta, size, expected):
360362
assert_moment_is_expected(model, expected)
361363

362364

365+
@pytest.mark.parametrize(
366+
"alpha, m, size, expected",
367+
[
368+
(2, 1, None, 1 * 2 ** (1 / 2)),
369+
(2, 1, 5, np.full(5, 1 * 2 ** (1 / 2))),
370+
(np.arange(2, 7), np.arange(1, 6), None, np.arange(1, 6) * 2 ** (1 / np.arange(2, 7))),
371+
(
372+
np.arange(2, 7),
373+
np.arange(1, 6),
374+
(2, 5),
375+
np.full((2, 5), np.arange(1, 6) * 2 ** (1 / np.arange(2, 7))),
376+
),
377+
],
378+
)
379+
def test_pareto_moment(alpha, m, size, expected):
380+
with Model() as model:
381+
Pareto("x", alpha=alpha, m=m, size=size)
382+
assert_moment_is_expected(model, expected)
383+
384+
385+
@pytest.mark.parametrize(
386+
"mu, lam, phi, size, expected",
387+
[
388+
(2, None, None, None, 2),
389+
(None, 1, 1, 5, np.full(5, 1)),
390+
(1, None, np.ones(5), None, np.full(5, 1)),
391+
(3, np.full(5, 2), None, None, np.full(5, 3)),
392+
(np.arange(1, 6), None, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6))),
393+
],
394+
)
395+
def test_wald_moment(mu, lam, phi, size, expected):
396+
with Model() as model:
397+
Wald("x", mu=mu, lam=lam, phi=phi, size=size)
398+
assert_moment_is_expected(model, expected)
399+
400+
363401
@pytest.mark.parametrize(
364402
"alpha, beta, size, expected",
365403
[

0 commit comments

Comments
 (0)