
Conversation

@VARUN3WARE

Summary

This PR implements the inverse CDF (percent point function) for Gamma and Beta distributions, as requested in issue #20358.

Changes

  • jax.scipy.special.gammaincinv: Inverse of the regularized lower incomplete gamma function using Newton-Halley hybrid method with Wilson-Hilferty transformation for initial guesses
  • jax.scipy.special.betaincinv: Inverse of the regularized incomplete beta function using Newton's method with adaptive step sizing
  • jax.scipy.stats.gamma.ppf: Percent point function (inverse CDF) for the Gamma distribution
  • jax.scipy.stats.beta.ppf: Percent point function (inverse CDF) for the Beta distribution
  • Custom JVP rules for all inverse functions to enable automatic differentiation
  • Comprehensive test suite covering:
    • Inversion properties (forward→inverse→forward consistency)
    • Edge cases (boundary values, invalid inputs)
    • Gradient correctness and finiteness
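
The Newton iteration with a Wilson-Hilferty starting point can be sketched in pure Python. This is a simplified scalar illustration of the technique described above (series-based CDF, Wilson-Hilferty initial guess, damped Newton updates), not the PR's JAX implementation; all names here are illustrative:

```python
import math
from statistics import NormalDist

def reg_lower_gamma(a, x):
    """Regularized lower incomplete gamma P(a, x) via its power series."""
    # P(a, x) = x^a e^{-x} / Gamma(a) * sum_{n>=0} x^n / (a (a+1) ... (a+n))
    term = 1.0 / a
    total = term
    denom = a
    for _ in range(200):
        denom += 1.0
        term *= x / denom
        total += term
        if term < total * 1e-16:
            break
    return total * math.exp(a * math.log(x) - x - math.lgamma(a))

def gammaincinv_sketch(a, y, iters=25):
    """Solve P(a, x) = y for x with damped Newton iterations."""
    z = NormalDist().inv_cdf(y)
    # Wilson-Hilferty: Gamma(a) is approximately a * N(1 - 1/(9a), 1/(9a))^3
    t = 1.0 - 1.0 / (9.0 * a) + z * math.sqrt(1.0 / (9.0 * a))
    x = max(a * t ** 3, 1e-8)
    for _ in range(iters):
        f = reg_lower_gamma(a, x) - y
        pdf = math.exp((a - 1.0) * math.log(x) - x - math.lgamma(a))
        # Damp the step so the iterate stays positive
        x = max(x - f / pdf, 0.5 * x)
    return x
```

For a = 2, y = 0.5 this recovers the median of Gamma(2, 1), approximately 1.6783.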

Use Cases

This enables:

  • Quantile calculations for Gamma and Beta distributions
  • Reparameterization trick in variational inference
  • Statistical analysis requiring inverse CDF operations
  • Compatibility with SciPy's scipy.stats.gamma.ppf and scipy.stats.beta.ppf

Testing

All tests pass, with accuracy matching SciPy within the following tolerances:

  • gammaincinv: typically 1e-6 accuracy
  • betaincinv: typically 1e-4 accuracy
  • Gradients verified to be positive and finite
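
The inversion-property check can be illustrated with a self-contained sketch. Beta(2, 2) has the closed-form CDF F(x) = 3x^2 - 2x^3 on [0, 1], so the forward→inverse→forward round trip can be verified without SciPy (bisection stands in for the Newton solver used in the PR; names are illustrative):

```python
def beta22_cdf(x):
    """CDF of Beta(2, 2): F(x) = 3x^2 - 2x^3 on [0, 1]."""
    return 3.0 * x ** 2 - 2.0 * x ** 3

def beta22_ppf(y, tol=1e-12):
    """Invert the Beta(2, 2) CDF by bisection (F is monotone on [0, 1])."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if beta22_cdf(mid) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Forward -> inverse -> forward consistency
for y in (0.1, 0.5, 0.9):
    assert abs(beta22_cdf(beta22_ppf(y)) - y) < 1e-9
```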

Fixes #20358

- Implement jax.scipy.special.gammaincinv using Newton-Halley method
- Implement jax.scipy.special.betaincinv using Newton's method
- Implement jax.scipy.stats.gamma.ppf
- Implement jax.scipy.stats.beta.ppf
- Add custom JVP rules for differentiability
- Add comprehensive tests with gradient checks
- Enable reparameterization trick for variational inference
@google-cla

google-cla bot commented Nov 28, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@gemini-code-assist
Contributor

Summary of Changes

Hello @VARUN3WARE, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the percent point function (PPF), also known as the inverse cumulative distribution function (inverse CDF), for both Gamma and Beta distributions within JAX's SciPy module. This enhancement provides critical functionality for quantile calculations, facilitates the reparameterization trick in variational inference, and supports various statistical analyses that rely on inverse CDF operations, all while maintaining compatibility with SciPy's existing ppf functions and supporting automatic differentiation.

Highlights

  • Inverse CDF for Gamma Distribution: Implemented jax.scipy.special.gammaincinv and jax.scipy.stats.gamma.ppf for the Gamma distribution, enabling quantile calculations.
  • Inverse CDF for Beta Distribution: Implemented jax.scipy.special.betaincinv and jax.scipy.stats.beta.ppf for the Beta distribution, providing its inverse CDF functionality.
  • Automatic Differentiation Support: Custom JVP rules were added for all new inverse functions (gammaincinv and betaincinv) to ensure compatibility with JAX's automatic differentiation system.
  • Robust Numerical Methods: gammaincinv utilizes a Newton-Halley hybrid method with Wilson-Hilferty transformation for initial guesses, while betaincinv employs Newton's method with adaptive step sizing for robust convergence.
  • Comprehensive Testing: Extensive test suites were added to verify inversion properties, handle various edge cases, and confirm the correctness and finiteness of gradients, matching SciPy's accuracy.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Supported commands:

  • /gemini review: Performs a code review for the current pull request in its current state.
  • /gemini summary: Provides a summary of the current pull request in its current state.
  • @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • /gemini help: Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the inverse CDF (percent point function) for Gamma and Beta distributions, which is a valuable addition to JAX's statistical functions. The implementations for gammaincinv and betaincinv use appropriate numerical methods, and the accompanying tests are thorough. My main feedback concerns the custom JVP rules, which are incomplete and could lead to silently incorrect gradients. I've suggested changes to make these limitations explicit. I also found a few minor areas for improvement in comments and code clarity.

Comment on lines +418 to +420
# For now, only implement the y derivative (most common use case)
# Derivatives w.r.t. a and b are more complex
x_dot = y_dot * dxdy

high

The JVP rule for betaincinv is incomplete. It only computes the derivative with respect to y and silently ignores the contributions from a and b. This will lead to incorrect gradients if betaincinv is differentiated with respect to a or b. To prevent silent errors, please either implement the derivatives for a and b or add a check to raise a NotImplementedError if non-zero tangents are provided for these parameters. The custom gradient tests for this function also seem to only cover the derivative with respect to y.

Suggested change (from):

    # For now, only implement the y derivative (most common use case)
    # Derivatives w.r.t. a and b are more complex
    x_dot = y_dot * dxdy

Suggested change (to):

    from jax._src.ad_util import Zero
    # For now, only implement the y derivative (most common use case)
    # Derivatives w.r.t. a and b are more complex
    if not isinstance(a_dot, Zero):
      raise NotImplementedError("The JVP of betaincinv with respect to `a` is not implemented.")
    if not isinstance(b_dot, Zero):
      raise NotImplementedError("The JVP of betaincinv with respect to `b` is not implemented.")
    x_dot = y_dot * dxdy

Comment on lines +645 to +646
# For now, only implement the y derivative (most common use case)
x_dot = y_dot * dxdy

high

Similar to betaincinv, the JVP rule for gammaincinv is incomplete. It only computes the derivative with respect to y and silently ignores the contribution from a. This will produce incorrect gradients when differentiating with respect to a. Please either implement the derivative for a or raise a NotImplementedError if a non-zero tangent is provided. The custom gradient tests for this function also seem to only cover the derivative with respect to y.

Suggested change (from):

    # For now, only implement the y derivative (most common use case)
    x_dot = y_dot * dxdy

Suggested change (to):

    from jax._src.ad_util import Zero
    if not isinstance(a_dot, Zero):
      raise NotImplementedError("The JVP of gammaincinv with respect to `a` is not implemented.")
    x_dot = y_dot * dxdy
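
For reference, the dxdy factor that both JVP rules do compute follows from the inverse function theorem: for a CDF F with density f, d/dy F⁻¹(y) = 1 / f(F⁻¹(y)). A minimal check with the exponential distribution, whose inverse CDF is closed-form (illustrative sketch, not code from this PR):

```python
import math

def exp_ppf(y):
    """Inverse CDF of Exp(1): F(x) = 1 - e^{-x}, so F^{-1}(y) = -log(1 - y)."""
    return -math.log1p(-y)

def exp_pdf(x):
    """Density of Exp(1): f(x) = e^{-x}."""
    return math.exp(-x)

y = 0.3
x = exp_ppf(y)
implicit = 1.0 / exp_pdf(x)   # inverse function theorem: 1 / f(F^{-1}(y))
analytic = 1.0 / (1.0 - y)    # direct derivative of -log(1 - y)
assert abs(implicit - analytic) < 1e-12
```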

- :func:`jax.scipy.special.betaln`
Notes:
This function uses a Newton-Halley hybrid iterative method to find the inverse.

medium

The docstring states that this function uses a "Newton-Halley hybrid iterative method", but the implementation comment at line 357 and the code itself indicate that Newton's method is used for stability. Please update the docstring to accurately reflect the implementation.

Suggested change (from):

    This function uses a Newton-Halley hybrid iterative method to find the inverse.

Suggested change (to):

    This function uses Newton's method with adaptive step sizing to find the inverse.

Comment on lines +321 to +332
def mode_guess(a, b):
  """Use the mode when both a,b > 1."""
  return (a - one) / (a + b - dtype(2))

def quantile_approx_guess(a, b, y):
  """Approximation based on normal distribution for a,b large."""
  # When a, b are large, beta(a,b) ~ Normal(mean, variance)
  mean = a / (a + b)
  var = (a * b) / ((a + b) ** 2 * (a + b + one))
  # Use inverse normal
  z = ndtri(y)
  return mean + lax.sqrt(var) * z

medium

The helper functions mode_guess and quantile_approx_guess are defined but do not appear to be used in the betaincinv implementation. To improve code clarity and maintainability, please remove this dead code.

# => (d/da gammainc(a,x))|_x + (d/dx gammainc(a,x)) * dx/da = 0
# => dx/da = -(d/da gammainc(a,x)|_x) / (d/dx gammainc(a,x))
#
# d/da gammainc(a,x)|_x is complex, so for now we only support y_dot

medium

The comment d/da gammainc(a,x)|_x is complex is misleading. For real a and x, this derivative is real-valued, although it is complicated to compute (it involves the digamma function). Please update the comment for accuracy.

Suggested change (from):

    # d/da gammainc(a,x)|_x is complex, so for now we only support y_dot

Suggested change (to):

    # d/da gammainc(a,x)|_x is complicated to compute, so for now we only support y_dot

Successfully merging this pull request may close these issues.

inverse CDF for Gamma/Beta distributions