Skip to content

Commit b832ba2

Browse files
author
juan.lopez.arriaza
committed
Adding moment for Rice distribution and associated test
1 parent aae9dfb commit b832ba2

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3350,6 +3350,22 @@ def get_nu_b(cls, nu, b, sigma):
33503350
return nu, b, sigma
33513351
raise ValueError("Rice distribution must specify either nu" " or b.")
33523352

3353+
def get_moment(rv, size, nu, sigma):
3354+
nu_sigma_ratio = -(nu ** 2) / (2 * sigma ** 2)
3355+
mean = (
3356+
sigma
3357+
* np.sqrt(np.pi / 2)
3358+
* at.exp(nu_sigma_ratio / 2)
3359+
* (
3360+
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
3361+
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
3362+
)
3363+
)
3364+
3365+
if not rv_size_is_none(size):
3366+
mean = at.full(size, mean)
3367+
return mean
3368+
33533369
def logp(value, b, sigma):
33543370
"""
33553371
Calculate log-probability of Rice distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
Normal,
3737
Pareto,
3838
Poisson,
39+
Rice,
3940
SkewNormal,
4041
StudentT,
4142
Triangular,
@@ -757,7 +758,7 @@ def test_categorical_moment(p, size, expected):
757758
"mu, sigma, size, expected",
758759
[
759760
(4.0, 3.0, None, 7.8110885363844345),
760-
(4, np.full(5, 3), None, np.full(5, 7.8110885363844345)),
761+
(4.0, np.full(5, 3), None, np.full(5, 7.8110885363844345)),
761762
(np.arange(5), 1, None, np.arange(5) + 1.2703628454614782),
762763
(np.arange(5), np.ones(5), (2, 5), np.full((2, 5), np.arange(5) + 1.2703628454614782)),
763764
],
@@ -772,7 +773,7 @@ def test_moyal_moment(mu, sigma, size, expected):
772773
"alpha, mu, sigma, size, expected",
773774
[
774775
(1.0, 1.0, 1.0, None, 1.56418958),
775-
(1, np.ones(5), 1, None, np.full(5, 1.56418958)),
776+
(1.0, np.ones(5), 1.0, None, np.full(5, 1.56418958)),
776777
(np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)),
777778
(
778779
np.arange(5),
@@ -822,3 +823,43 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
822823
with Model() as model:
823824
AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size)
824825
assert_moment_is_expected(model, expected)
826+
827+
828+
@pytest.mark.parametrize(
829+
"nu, sigma, size, expected",
830+
[
831+
(1.0, 1.0, None, 1.5485724605511453),
832+
(1.0, np.ones(5), None, np.full(5, 1.5485724605511453)),
833+
(
834+
np.arange(1, 6),
835+
1.0,
836+
None,
837+
(
838+
1.5485724605511453,
839+
2.2723834280687427,
840+
3.1725772879007166,
841+
4.127193542536757,
842+
5.101069639492123,
843+
),
844+
),
845+
(
846+
np.arange(1, 6),
847+
np.ones(5),
848+
(2, 5),
849+
np.full(
850+
(2, 5),
851+
(
852+
1.5485724605511453,
853+
2.2723834280687427,
854+
3.1725772879007166,
855+
4.127193542536757,
856+
5.101069639492123,
857+
),
858+
),
859+
),
860+
],
861+
)
862+
def test_rice_moment(nu, sigma, size, expected):
863+
with Model() as model:
864+
Rice("x", nu=nu, sigma=sigma, size=size)
865+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)