Implement betaincinv and gammainc[c]inv functions #502

Merged · 18 commits · Jan 2, 2024
18 changes: 18 additions & 0 deletions pytensor/link/jax/dispatch/scalar.py
@@ -21,11 +21,14 @@
    Sub,
)
from pytensor.scalar.math import (
    BetaIncInv,
    Erf,
    Erfc,
    Erfcinv,
    Erfcx,
    Erfinv,
    GammaIncCInv,
    GammaIncInv,
    Iv,
    Ive,
    Log1mexp,
@@ -226,6 +229,20 @@ def second(x, y):
    return second


@jax_funcify.register(GammaIncInv)
def jax_funcify_GammaIncInv(op, **kwargs):
    gammaincinv = try_import_tfp_jax_op(op, jax_op_name="igammainv")

    return gammaincinv


@jax_funcify.register(GammaIncCInv)
def jax_funcify_GammaIncCInv(op, **kwargs):
    gammainccinv = try_import_tfp_jax_op(op, jax_op_name="igammacinv")

    return gammainccinv


@jax_funcify.register(Erf)
def jax_funcify_Erf(op, node, **kwargs):
    def erf(x):
@@ -250,6 +267,7 @@ def erfinv(x):
    return erfinv


@jax_funcify.register(BetaIncInv)
@jax_funcify.register(Erfcx)
@jax_funcify.register(Erfcinv)
def jax_funcify_from_tfp(op, **kwargs):
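For context (not part of the diff): via `try_import_tfp_jax_op`, the two new dispatches resolve to TensorFlow Probability's JAX substrate. A minimal sketch, assuming `tensorflow-probability` with JAX support is installed, checking that those kernels agree with the SciPy reference used by the Python-mode `Op`s:

```python
# Illustrative only: compare TFP's JAX kernels against the SciPy reference.
import numpy as np
import scipy.special
from tensorflow_probability.substrates import jax as tfp

k = np.array([5.5, 7.0])
x = np.array([0.25, 0.7])

np.testing.assert_allclose(
    np.asarray(tfp.math.igammainv(k, x)), scipy.special.gammaincinv(k, x), rtol=1e-5
)
np.testing.assert_allclose(
    np.asarray(tfp.math.igammacinv(k, x)), scipy.special.gammainccinv(k, x), rtol=1e-5
)
```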
95 changes: 95 additions & 0 deletions pytensor/scalar/math.py
@@ -733,6 +733,64 @@ def __hash__(self):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")


class GammaIncInv(BinaryScalarOp):
    """
    Inverse of the regularized lower incomplete gamma function.
    """

    nfunc_spec = ("scipy.special.gammaincinv", 2, 1)

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammaincinv(k, x)

    def impl(self, k, x):
        return GammaIncInv.st_impl(k, x)

    def grad(self, inputs, grads):
        (k, x) = inputs
        (gz,) = grads
        return [
            grad_not_implemented(self, 0, k),
            gz * exp(gammaincinv(k, x)) * gamma(k) * (gammaincinv(k, x) ** (1 - k)),
        ]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


gammaincinv = GammaIncInv(upgrade_to_float, name="gammaincinv")


class GammaIncCInv(BinaryScalarOp):
    """
    Inverse of the regularized upper incomplete gamma function.
    """

    nfunc_spec = ("scipy.special.gammainccinv", 2, 1)

    @staticmethod
    def st_impl(k, x):
        return scipy.special.gammainccinv(k, x)

    def impl(self, k, x):
        return GammaIncCInv.st_impl(k, x)

    def grad(self, inputs, grads):
        (k, x) = inputs
        (gz,) = grads
        return [
            grad_not_implemented(self, 0, k),
            gz * -exp(gammainccinv(k, x)) * gamma(k) * (gammainccinv(k, x) ** (1 - k)),
        ]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


gammainccinv = GammaIncCInv(upgrade_to_float, name="gammainccinv")

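Aside (not in the diff): the `x` gradients above follow from the inverse function theorem. With $y = P^{-1}(k, x)$ for the regularized lower incomplete gamma function $P$,

```math
\frac{\partial P(k, y)}{\partial y} = \frac{y^{k-1} e^{-y}}{\Gamma(k)}
\qquad\Longrightarrow\qquad
\frac{\partial}{\partial x}\, P^{-1}(k, x) = \Gamma(k)\, e^{y}\, y^{1-k} .
```

For the upper function $Q = 1 - P$ the derivative picks up a minus sign, which is the leading `-` in `GammaIncCInv.grad`.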

def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
    init = [as_scalar(x) if x is not None else None for x in init]
    constant = [as_scalar(x) for x in constant]
@@ -1648,6 +1706,43 @@ def inner_loop(
    return grad


class BetaIncInv(ScalarOp):
    """
    Inverse of the regularized incomplete beta function.
    """

    nfunc_spec = ("scipy.special.betaincinv", 3, 1)

    def impl(self, a, b, x):
        return scipy.special.betaincinv(a, b, x)

    def grad(self, inputs, grads):
        (a, b, x) = inputs
        (gz,) = grads
        return [
            grad_not_implemented(self, 0, a),
            grad_not_implemented(self, 0, b),
            gz
            * exp(betaln(a, b))
            * ((1 - betaincinv(a, b, x)) ** (1 - b))
            * (betaincinv(a, b, x) ** (1 - a)),
        ]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


betaincinv = BetaIncInv(upgrade_to_float_no_complex, name="betaincinv")

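Same reasoning (not in the diff): with $y = I^{-1}_x(a, b)$ and $\partial I_y(a, b) / \partial y = y^{a-1}(1-y)^{b-1} / B(a, b)$,

```math
\frac{\partial}{\partial x}\, I^{-1}_x(a, b)
= B(a, b)\, (1 - y)^{1-b}\, y^{1-a}
= e^{\operatorname{betaln}(a, b)}\, (1 - y)^{1-b}\, y^{1-a},
```

which is exactly the expression returned for the `x` input, with $B(a, b)$ evaluated in log space via the `betaln` helper below.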

def betaln(a, b):
"""
Beta function from gamma function.
"""

return gammaln(a) + gammaln(b) - gammaln(a + b)


class Hyp2F1(ScalarOp):
"""
Gaussian hypergeometric function ``2F1(a, b; c; z)``.
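A quick numerical spot-check of the closed-form gradient above (illustrative only, not part of the PR):

```python
# Verify d/dx gammaincinv(k, x) against a central finite difference.
import numpy as np
from scipy.special import gamma, gammaincinv

k, x, eps = 5.5, 0.25, 1e-6
y = gammaincinv(k, x)
analytic = gamma(k) * np.exp(y) * y ** (1 - k)
numeric = (gammaincinv(k, x + eps) - gammaincinv(k, x - eps)) / (2 * eps)
assert np.isclose(analytic, numeric, rtol=1e-5)
```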
15 changes: 15 additions & 0 deletions pytensor/tensor/inplace.py
@@ -283,6 +283,16 @@ def gammal_inplace(k, x):
    """lower incomplete gamma function"""


@scalar_elemwise
def gammaincinv_inplace(k, x):
    """Inverse of the regularized lower incomplete gamma function"""


@scalar_elemwise
def gammainccinv_inplace(k, x):
    """Inverse of the regularized upper incomplete gamma function"""


@scalar_elemwise
def j0_inplace(x):
    """Bessel function of the first kind of order 0."""
@@ -338,6 +348,11 @@ def betainc_inplace(a, b, x):
    """Regularized incomplete beta function"""


@scalar_elemwise
def betaincinv_inplace(a, b, x):
    """Inverse of the regularized incomplete beta function"""


@scalar_elemwise
def second_inplace(a):
    """Fill `a` with `b`"""
18 changes: 18 additions & 0 deletions pytensor/tensor/math.py
@@ -1385,6 +1385,16 @@ def gammal(k, x):
    """Lower incomplete gamma function."""


@scalar_elemwise
def gammaincinv(k, x):
    """Inverse of the regularized lower incomplete gamma function"""


@scalar_elemwise
def gammainccinv(k, x):
    """Inverse of the regularized upper incomplete gamma function"""


@scalar_elemwise
def hyp2f1(a, b, c, z):
    """Gaussian hypergeometric function."""
@@ -1451,6 +1461,11 @@ def betainc(a, b, x):
    """Regularized incomplete beta function"""


@scalar_elemwise
def betaincinv(a, b, x):
    """Inverse of the regularized incomplete beta function"""


@scalar_elemwise
def real(z):
    """Return real component of complex-valued tensor `z`."""
@@ -3044,6 +3059,8 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
    "gammaincc",
    "gammau",
    "gammal",
    "gammaincinv",
    "gammainccinv",
    "j0",
    "j1",
    "jv",
@@ -3057,6 +3074,7 @@ def vectorize_node_to_matmul(op, node, batched_x, batched_y):
    "log1pexp",
    "log1mexp",
    "betainc",
    "betaincinv",
    "real",
    "imag",
    "angle",
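A usage sketch of the new tensor-level wrappers (assuming this branch is installed); like the other `scalar_elemwise` functions they broadcast elementwise, and the `x` gradient defined above is available through `pytensor.grad`:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

k = pt.vector("k")
x = pt.vector("x")
out = pt.gammaincinv(k, x)
# Compile the forward pass together with d(sum(out))/dx.
f = pytensor.function([k, x], [out, pytensor.grad(out.sum(), x)])
val, dval = f(np.array([5.5, 7.0]), np.array([0.25, 0.7]))
# val should match scipy.special.gammaincinv(k, x) elementwise.
```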
20 changes: 19 additions & 1 deletion pytensor/tensor/special.py
@@ -6,7 +6,7 @@
from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import gamma, neg, sum
from pytensor.tensor.math import gamma, gammaln, neg, sum


class SoftmaxGrad(COp):
@@ -752,9 +752,27 @@ def factorial(n):
    return gamma(n + 1)


def beta(a, b):
    """
    Beta function.

    """
    return (gamma(a) * gamma(b)) / gamma(a + b)


def betaln(a, b):
    """
    Log beta function.

    """
    return gammaln(a) + gammaln(b) - gammaln(a + b)


__all__ = [
    "softmax",
    "log_softmax",
    "poch",
    "factorial",
    "beta",
    "betaln",
]
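Why both helpers: `beta` overflows for moderately large arguments while `betaln` stays finite, which is presumably why `BetaIncInv.grad` goes through `exp(betaln(a, b))`. A small sketch with illustrative values:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import beta, betaln

a = pt.dscalar("a")
b = pt.dscalar("b")
f = pytensor.function([a, b], [beta(a, b), betaln(a, b)])
# beta hits inf/inf for large inputs; betaln stays finite (roughly -695 here).
print(f(500.0, 500.0))
```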
35 changes: 35 additions & 0 deletions tests/link/jax/test_scalar.py
@@ -11,12 +11,15 @@
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import (
    betaincinv,
    cosh,
    erf,
    erfc,
    erfcinv,
    erfcx,
    erfinv,
    gammainccinv,
    gammaincinv,
    iv,
    log,
    log1mexp,
@@ -165,6 +168,38 @@ def test_tfp_ops(op, test_values):
    compare_jax_and_py(fg, test_values)


def test_betaincinv():
    a = vector("a", dtype="float64")
    b = vector("b", dtype="float64")
    x = vector("x", dtype="float64")
    out = betaincinv(a, b, x)
    fg = FunctionGraph([a, b, x], [out])
    compare_jax_and_py(
        fg,
        [
            np.array([5.5, 7.0]),
            np.array([5.5, 7.0]),
            np.array([0.25, 0.7]),
        ],
    )


def test_gammaincinv():
    k = vector("k", dtype="float64")
    x = vector("x", dtype="float64")
    out = gammaincinv(k, x)
    fg = FunctionGraph([k, x], [out])
    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])


def test_gammainccinv():
    k = vector("k", dtype="float64")
    x = vector("x", dtype="float64")
    out = gammainccinv(k, x)
    fg = FunctionGraph([k, x], [out])
    compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])


def test_psi():
    x = scalar("x")
    out = psi(x)