Skip to content

Commit 4b7aaad

Browse files
juan.lopez.arriazaricardoV94
juan.lopez.arriaza
authored andcommitted
Adding moment for Rice distribution and associated test
1 parent e257fe0 commit 4b7aaad

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3350,6 +3350,22 @@ def get_nu_b(cls, nu, b, sigma):
33503350
return nu, b, sigma
33513351
raise ValueError("Rice distribution must specify either nu" " or b.")
33523352

3353+
def get_moment(rv, size, nu, sigma):
3354+
nu_sigma_ratio = -(nu ** 2) / (2 * sigma ** 2)
3355+
mean = (
3356+
sigma
3357+
* np.sqrt(np.pi / 2)
3358+
* at.exp(nu_sigma_ratio / 2)
3359+
* (
3360+
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
3361+
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
3362+
)
3363+
)
3364+
3365+
if not rv_size_is_none(size):
3366+
mean = at.full(size, mean)
3367+
return mean
3368+
33533369
def logp(value, b, sigma):
33543370
"""
33553371
Calculate log-probability of Rice distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
Normal,
3737
Pareto,
3838
Poisson,
39+
Rice,
3940
SkewNormal,
4041
StudentT,
4142
Triangular,
@@ -792,7 +793,7 @@ def test_mv_normal_moment(mu, cov, size, expected):
792793
"mu, sigma, size, expected",
793794
[
794795
(4.0, 3.0, None, 7.8110885363844345),
795-
(4, np.full(5, 3), None, np.full(5, 7.8110885363844345)),
796+
(4.0, np.full(5, 3), None, np.full(5, 7.8110885363844345)),
796797
(np.arange(5), 1, None, np.arange(5) + 1.2703628454614782),
797798
(np.arange(5), np.ones(5), (2, 5), np.full((2, 5), np.arange(5) + 1.2703628454614782)),
798799
],
@@ -807,7 +808,7 @@ def test_moyal_moment(mu, sigma, size, expected):
807808
"alpha, mu, sigma, size, expected",
808809
[
809810
(1.0, 1.0, 1.0, None, 1.56418958),
810-
(1, np.ones(5), 1, None, np.full(5, 1.56418958)),
811+
(1.0, np.ones(5), 1.0, None, np.full(5, 1.56418958)),
811812
(np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)),
812813
(
813814
np.arange(5),
@@ -857,3 +858,43 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
857858
with Model() as model:
858859
AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size)
859860
assert_moment_is_expected(model, expected)
861+
862+
863+
@pytest.mark.parametrize(
864+
"nu, sigma, size, expected",
865+
[
866+
(1.0, 1.0, None, 1.5485724605511453),
867+
(1.0, np.ones(5), None, np.full(5, 1.5485724605511453)),
868+
(
869+
np.arange(1, 6),
870+
1.0,
871+
None,
872+
(
873+
1.5485724605511453,
874+
2.2723834280687427,
875+
3.1725772879007166,
876+
4.127193542536757,
877+
5.101069639492123,
878+
),
879+
),
880+
(
881+
np.arange(1, 6),
882+
np.ones(5),
883+
(2, 5),
884+
np.full(
885+
(2, 5),
886+
(
887+
1.5485724605511453,
888+
2.2723834280687427,
889+
3.1725772879007166,
890+
4.127193542536757,
891+
5.101069639492123,
892+
),
893+
),
894+
),
895+
],
896+
)
897+
def test_rice_moment(nu, sigma, size, expected):
898+
with Model() as model:
899+
Rice("x", nu=nu, sigma=sigma, size=size)
900+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)