Skip to content

Commit bb09a5a

Browse files
Add get_moment implementations for Normal, Uniform and Binomial
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 2b9e465 commit bb09a5a

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ def logcdf(value, lower, upper):
336336
),
337337
)
338338

339+
def get_moment(value, size, lower, upper):
340+
lower = at.full(size, lower, dtype=aesara.config.floatX)
341+
upper = at.full(size, upper, dtype=aesara.config.floatX)
342+
return (lower + upper) / 2
343+
339344

340345
class FlatRV(RandomVariable):
341346
name = "flat"
@@ -366,7 +371,7 @@ def dist(cls, *, size=None, **kwargs):
366371
res.tag.test_value = np.full(size, floatX(0.0))
367372
return res
368373

369-
def get_moment(rv, size, *rv_inputs) -> np.ndarray:
374+
def get_moment(rv, size, *rv_inputs):
370375
return at.zeros(size, dtype=aesara.config.floatX)
371376

372377
def logp(value):
@@ -431,7 +436,7 @@ def dist(cls, *, size=None, **kwargs):
431436
res.tag.test_value = np.full(size, floatX(1.0))
432437
return res
433438

434-
def get_moment(value_var, size, *rv_inputs) -> np.ndarray:
439+
def get_moment(value_var, size, *rv_inputs):
435440
return at.ones(size, dtype=aesara.config.floatX)
436441

437442
def logp(value):
@@ -588,6 +593,9 @@ def logcdf(value, mu, sigma):
588593
0 < sigma,
589594
)
590595

596+
def get_moment(value_var, size, mu, sigma):
597+
return at.full(size, mu, dtype=aesara.config.floatX)
598+
591599

592600
class TruncatedNormalRV(RandomVariable):
593601
name = "truncated_normal"

pymc/distributions/discrete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ def logcdf(value, p):
394394
p <= 1,
395395
)
396396

397+
def get_moment(value, size, p):
398+
p = at.full(size, p)
399+
return at.switch(p < 0.5, np.int64(0), np.int64(1))
400+
397401
def _distr_parameters_for_repr(self):
398402
return ["p"]
399403

pymc/tests/test_initvals.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ def test_automatically_assigned_test_values(self):
9595

9696
class TestMoment:
9797
def test_basic(self):
98+
# Standard distributions
99+
rv = pm.Normal.dist(mu=2.3)
100+
np.testing.assert_allclose(get_moment(rv).eval(), 2.3)
101+
102+
# Special distributions
98103
rv = pm.Flat.dist()
99104
assert get_moment(rv).eval() == np.zeros(())
100105
rv = pm.HalfFlat.dist()

0 commit comments

Comments
 (0)