Skip to content

Commit 77c9602

Browse files
authored
Merge branch 'main' into main
2 parents e1b8f0c + 93c2293 commit 77c9602

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,16 +2634,18 @@ def dist(cls, nu=1, sigma=None, lam=None, sd=None, *args, **kwargs):
26342634
lam, sigma = get_tau_sigma(lam, sigma)
26352635
sigma = at.as_tensor_variable(sigma)
26362636

2637-
# mode = at.as_tensor_variable(0)
2638-
# median = at.as_tensor_variable(sigma)
2639-
# sd = at.as_tensor_variable(sigma)
2640-
26412637
assert_negative_support(nu, "nu", "HalfStudentT")
26422638
assert_negative_support(lam, "lam", "HalfStudentT")
26432639
assert_negative_support(sigma, "sigma", "HalfStudentT")
26442640

26452641
return super().dist([nu, sigma], *args, **kwargs)
26462642

2643+
def get_moment(rv, size, nu, sigma):
2644+
sigma, _ = at.broadcast_arrays(sigma, nu)
2645+
if not rv_size_is_none(size):
2646+
sigma = at.full(size, sigma)
2647+
return sigma
2648+
26472649
def logp(value, nu, sigma):
26482650
"""
26492651
Calculate log-probability of HalfStudentT distribution at specified value.

pymc/distributions/discrete.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,11 @@ def dist(cls, c, *args, **kwargs):
12171217
c = at.as_tensor_variable(floatX(c))
12181218
return super().dist([c], **kwargs)
12191219

1220+
def get_moment(rv, size, c):
1221+
if not rv_size_is_none(size):
1222+
c = at.full(size, c)
1223+
return c
1224+
12201225
def logp(value, c):
12211226
r"""
12221227
Calculate log-probability of Constant distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
Binomial,
1010
Cauchy,
1111
ChiSquared,
12+
Constant,
1213
Exponential,
1314
Gamma,
1415
HalfCauchy,
1516
HalfNormal,
17+
HalfStudentT,
1618
Kumaraswamy,
1719
Laplace,
1820
Logistic,
@@ -131,6 +133,21 @@ def test_halfnormal_moment(sigma, size, expected):
131133
assert_moment_is_expected(model, expected)
132134

133135

136+
@pytest.mark.parametrize(
137+
"nu, sigma, size, expected",
138+
[
139+
(1, 1, None, 1),
140+
(1, 1, 5, np.ones(5)),
141+
(1, np.arange(5), (2, 5), np.full((2, 5), np.arange(5))),
142+
(np.arange(1, 6), 1, None, np.full(5, 1)),
143+
],
144+
)
145+
def test_halfstudentt_moment(nu, sigma, size, expected):
146+
with Model() as model:
147+
HalfStudentT("x", nu=nu, sigma=sigma, size=size)
148+
assert_moment_is_expected(model, expected)
149+
150+
134151
@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
135152
@pytest.mark.parametrize(
136153
"mu, sigma, lower, upper, size, expected",
@@ -385,6 +402,20 @@ def test_poisson_moment(mu, size, expected):
385402
assert_moment_is_expected(model, expected)
386403

387404

405+
@pytest.mark.parametrize(
406+
"c, size, expected",
407+
[
408+
(1, None, 1),
409+
(1, 5, np.full(5, 1)),
410+
(np.arange(1, 6), None, np.arange(1, 6)),
411+
],
412+
)
413+
def test_constant_moment(c, size, expected):
414+
with Model() as model:
415+
Constant("x", c=c, size=size)
416+
assert_moment_is_expected(model, expected)
417+
418+
388419
@pytest.mark.parametrize(
389420
"mu, s, size, expected",
390421
[
@@ -403,3 +434,4 @@ def test_logistic_moment(mu, s, size, expected):
403434
with Model() as model:
404435
Logistic("x", mu=mu, s=s, size=size)
405436
assert_moment_is_expected(model, expected)
437+

0 commit comments

Comments
 (0)