Skip to content

Commit b83ffca

Browse files
committed
Fix dtype casting bug in icdf function
1 parent 2a324bc commit b83ffca

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc/logprob/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def icdf(
127127
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
128128
) -> TensorVariable:
129129
"""Create a graph for the inverse CDF of a Random Variable."""
130-
value = pt.as_tensor_variable(value, dtype=rv.dtype)
130+
value = pt.as_tensor_variable(value, dtype="floatX")
131131
try:
132132
return _icdf_helper(rv, value, **kwargs)
133133
except NotImplementedError:

tests/logprob/test_basic.py

+11
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,14 @@ def test_warn_random_found_probability_inference(func, scipy_func, test_value):
510510
test_value
511511
),
512512
)
513+
514+
515+
def test_icdf_discrete():
516+
p = 0.1
517+
value = 0.9
518+
dist = pm.Geometric.dist(p=p)
519+
dist_icdf = icdf(dist, value)
520+
np.testing.assert_almost_equal(
521+
dist_icdf.eval(),
522+
sp.geom.ppf(value, p),
523+
)

0 commit comments

Comments
 (0)