Skip to content

Commit 741f207

Browse files
azihnaAlihan ZihnaricardoV94
authored
Add HalfStudentT moment (#5152)
Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Alihan Zihna <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 45d48a6 commit 741f207

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,16 +2634,18 @@ def dist(cls, nu=1, sigma=None, lam=None, sd=None, *args, **kwargs):
26342634
lam, sigma = get_tau_sigma(lam, sigma)
26352635
sigma = at.as_tensor_variable(sigma)
26362636

2637-
# mode = at.as_tensor_variable(0)
2638-
# median = at.as_tensor_variable(sigma)
2639-
# sd = at.as_tensor_variable(sigma)
2640-
26412637
assert_negative_support(nu, "nu", "HalfStudentT")
26422638
assert_negative_support(lam, "lam", "HalfStudentT")
26432639
assert_negative_support(sigma, "sigma", "HalfStudentT")
26442640

26452641
return super().dist([nu, sigma], *args, **kwargs)
26462642

2643+
def get_moment(rv, size, nu, sigma):
2644+
sigma, _ = at.broadcast_arrays(sigma, nu)
2645+
if not rv_size_is_none(size):
2646+
sigma = at.full(size, sigma)
2647+
return sigma
2648+
26472649
def logp(value, nu, sigma):
26482650
"""
26492651
Calculate log-probability of HalfStudentT distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Gamma,
1414
HalfCauchy,
1515
HalfNormal,
16+
HalfStudentT,
1617
Kumaraswamy,
1718
Laplace,
1819
LogNormal,
@@ -130,6 +131,21 @@ def test_halfnormal_moment(sigma, size, expected):
130131
assert_moment_is_expected(model, expected)
131132

132133

134+
@pytest.mark.parametrize(
135+
"nu, sigma, size, expected",
136+
[
137+
(1, 1, None, 1),
138+
(1, 1, 5, np.ones(5)),
139+
(1, np.arange(5), (2, 5), np.full((2, 5), np.arange(5))),
140+
(np.arange(1, 6), 1, None, np.full(5, 1)),
141+
],
142+
)
143+
def test_halfstudentt_moment(nu, sigma, size, expected):
144+
with Model() as model:
145+
HalfStudentT("x", nu=nu, sigma=sigma, size=size)
146+
assert_moment_is_expected(model, expected)
147+
148+
133149
@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
134150
@pytest.mark.parametrize(
135151
"mu, sigma, lower, upper, size, expected",

0 commit comments

Comments
 (0)