Skip to content

Commit 5cf7efe

Browse files
gokuldricardoV94
authored andcommitted
Fix numerical precision issues in discrete ICDFs.
1 parent 910d9ef commit 5cf7efe

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

pymc/distributions/discrete.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pymc.distributions.distribution import Discrete
4949
from pymc.distributions.mixture import Mixture
5050
from pymc.distributions.shape_utils import rv_size_is_none
51-
from pymc.logprob.basic import logp
51+
from pymc.logprob.basic import logcdf, logp
5252
from pymc.math import sigmoid
5353
from pymc.pytensorf import floatX, intX
5454
from pymc.vartypes import continuous_types
@@ -823,6 +823,10 @@ def logcdf(value, p):
823823

824824
def icdf(value, p):
825825
res = pt.ceil(pt.log1p(-value) / pt.log1p(-p)).astype("int64")
826+
res_1m = pt.maximum(res - 1, 0)
827+
dist = pm.Geometric.dist(p=p)
828+
value_1m = pt.exp(logcdf(dist, res_1m))
829+
res = pt.switch(value_1m >= value, res_1m, res)
826830
res = check_icdf_value(res, value)
827831
return check_icdf_parameters(
828832
res,
@@ -1060,6 +1064,11 @@ def logcdf(value, lower, upper):
10601064

10611065
def icdf(value, lower, upper):
10621066
res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1
1067+
res_1m = pt.maximum(res - 1, lower)
1068+
dist = pm.DiscreteUniform.dist(lower=lower, upper=upper)
1069+
value_1m = pt.exp(logcdf(dist, res_1m))
1070+
res = pt.switch(value_1m >= value, res_1m, res)
1071+
10631072
res = check_icdf_value(res, value)
10641073
return check_icdf_parameters(
10651074
res,

0 commit comments

Comments
 (0)