From b83ffcade554bcdb57b108e10eba47c7c80b422f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Apr 2023 10:29:07 +0200 Subject: [PATCH] Fix dtype casting bug in icdf function --- pymc/logprob/basic.py | 2 +- tests/logprob/test_basic.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index bb2126e33e..a8d4221f06 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -127,7 +127,7 @@ def icdf( rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs ) -> TensorVariable: """Create a graph for the inverse CDF of a Random Variable.""" - value = pt.as_tensor_variable(value, dtype=rv.dtype) + value = pt.as_tensor_variable(value, dtype="floatX") try: return _icdf_helper(rv, value, **kwargs) except NotImplementedError: diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 92e4df7dff..3e2d6bb9b8 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -510,3 +510,14 @@ def test_warn_random_found_probability_inference(func, scipy_func, test_value): test_value ), ) + + +def test_icdf_discrete(): + p = 0.1 + value = 0.9 + dist = pm.Geometric.dist(p=p) + dist_icdf = icdf(dist, value) + np.testing.assert_almost_equal( + dist_icdf.eval(), + sp.geom.ppf(value, p), + )