diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index c325c6a33a..b676b64c97 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -15,7 +15,7 @@ from pymc3.theanof import floatX from . import transforms -from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise +from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, zvalue from .distribution import Continuous, draw_values, generate_samples, Bound __all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace', @@ -231,6 +231,44 @@ def logp(self, value): return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2., sd > 0) + def cdf(self, value): + mu = self.mu + sd = self.sd + z = zvalue(value, mu=mu, sd=sd) + + return tt.erfc(-z / tt.sqrt(2.))/2. + + def ccdf(self, value): + mu = self.mu + sd = self.sd + z = zvalue(value, mu=mu, sd=sd) + + return tt.erfc(z / tt.sqrt(2.))/2. + + def lcdf(self, value): + mu = self.mu + sd = self.sd + z = zvalue(value, mu=mu, sd=sd) + + return tt.switch( + tt.lt(z, -1.0), + tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) - + tt.sqr(tt.abs_(z)) / 2, + tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.) + ) + + def lccdf(self, value): + mu = self.mu + sd = self.sd + z = zvalue(value, mu=mu, sd=sd) + + return tt.switch( + tt.gt(z, 1.0), + tt.log(tt.erfcx(z / tt.sqrt(2.)) / 2) - + tt.sqr(tt.abs_(z)) / 2., + tt.log1p(-tt.erfc(-z / tt.sqrt(2.)) / 2.) + ) + class HalfNormal(PositiveContinuous): R""" diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 22aa4c6085..8a1f848ca3 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -96,3 +96,9 @@ def i1(x): x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600, np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3) + 14175 / (98304 * x**4))) + +def zvalue(value, sd=1, mu=0): + """ + Calculate the z-value for a normal distribution. By default standard normal. + """ + return (value - mu) / tt.sqrt(2. * sd ** 2.)