|
48 | 48 | from pymc.distributions.distribution import Discrete
|
49 | 49 | from pymc.distributions.mixture import Mixture
|
50 | 50 | from pymc.distributions.shape_utils import rv_size_is_none
|
51 |
| -from pymc.logprob.basic import logp |
| 51 | +from pymc.logprob.basic import logcdf, logp |
52 | 52 | from pymc.math import sigmoid
|
53 | 53 | from pymc.pytensorf import floatX, intX
|
54 | 54 | from pymc.vartypes import continuous_types
|
@@ -823,6 +823,10 @@ def logcdf(value, p):
|
823 | 823 |
|
824 | 824 | def icdf(value, p):
|
825 | 825 | res = pt.ceil(pt.log1p(-value) / pt.log1p(-p)).astype("int64")
|
| 826 | + res_1m = pt.maximum(res - 1, 0) |
| 827 | + dist = pm.Geometric.dist(p=p) |
| 828 | + value_1m = pt.exp(logcdf(dist, res_1m)) |
| 829 | + res = pt.switch(value_1m >= value, res_1m, res) |
826 | 830 | res = check_icdf_value(res, value)
|
827 | 831 | return check_icdf_parameters(
|
828 | 832 | res,
|
@@ -1060,6 +1064,11 @@ def logcdf(value, lower, upper):
|
1060 | 1064 |
|
1061 | 1065 | def icdf(value, lower, upper):
|
1062 | 1066 | res = pt.ceil(value * (upper - lower + 1)).astype("int64") + lower - 1
|
| 1067 | + res_1m = pt.maximum(res - 1, lower) |
| 1068 | + dist = pm.DiscreteUniform.dist(lower=lower, upper=upper) |
| 1069 | + value_1m = pt.exp(logcdf(dist, res_1m)) |
| 1070 | + res = pt.switch(value_1m >= value, res_1m, res) |
| 1071 | + |
1063 | 1072 | res = check_icdf_value(res, value)
|
1064 | 1073 | return check_icdf_parameters(
|
1065 | 1074 | res,
|
|
0 commit comments