-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Add ppf (inverse CDF) for Gamma and Beta distributions (#20358) #33589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
- Implement jax.scipy.special.gammaincinv using Newton-Halley method - Implement jax.scipy.special.betaincinv using Newton's method - Implement jax.scipy.stats.gamma.ppf - Implement jax.scipy.stats.beta.ppf - Add custom JVP rules for differentiability - Add comprehensive tests with gradient checks - Enable reparameterization trick for variational inference
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
Summary of ChangesHello @VARUN3WARE, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces the percent point function (PPF), also known as the inverse cumulative distribution function (inverse CDF), for both Gamma and Beta distributions within JAX's SciPy module. This enhancement provides critical functionality for quantile calculations, facilitates the reparameterization trick in variational inference, and supports various statistical analyses that rely on inverse CDF operations, all while maintaining compatibility with SciPy's existing Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces the inverse CDF (percent point function) for Gamma and Beta distributions, which is a valuable addition to JAX's statistical functions. The implementations for gammaincinv and betaincinv use appropriate numerical methods, and the accompanying tests are thorough. My main feedback concerns the custom JVP rules, which are incomplete and could lead to silently incorrect gradients. I've suggested changes to make these limitations explicit. I also found a few minor areas for improvement in comments and code clarity.
| # For now, only implement the y derivative (most common use case) | ||
| # Derivatives w.r.t. a and b are more complex | ||
| x_dot = y_dot * dxdy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JVP rule for betaincinv is incomplete. It only computes the derivative with respect to y and silently ignores the contributions from a and b. This will lead to incorrect gradients if betaincinv is differentiated with respect to a or b. To prevent silent errors, please either implement the derivatives for a and b or add a check to raise a NotImplementedError if non-zero tangents are provided for these parameters. The custom gradient tests for this function also seem to only cover the derivative with respect to y.
| # For now, only implement the y derivative (most common use case) | |
| # Derivatives w.r.t. a and b are more complex | |
| x_dot = y_dot * dxdy | |
| from jax._src.ad_util import Zero | |
| # For now, only implement the y derivative (most common use case) | |
| # Derivatives w.r.t. a and b are more complex | |
| if not isinstance(a_dot, Zero): | |
| raise NotImplementedError("The JVP of betaincinv with respect to `a` is not implemented.") | |
| if not isinstance(b_dot, Zero): | |
| raise NotImplementedError("The JVP of betaincinv with respect to `b` is not implemented.") | |
| x_dot = y_dot * dxdy |
| # For now, only implement the y derivative (most common use case) | ||
| x_dot = y_dot * dxdy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to betaincinv, the JVP rule for gammaincinv is incomplete. It only computes the derivative with respect to y and silently ignores the contribution from a. This will produce incorrect gradients when differentiating with respect to a. Please either implement the derivative for a or raise a NotImplementedError if a non-zero tangent is provided. The custom gradient tests for this function also seem to only cover the derivative with respect to y.
| # For now, only implement the y derivative (most common use case) | |
| x_dot = y_dot * dxdy | |
| from jax._src.ad_util import Zero | |
| if not isinstance(a_dot, Zero): | |
| raise NotImplementedError("The JVP of gammaincinv with respect to `a` is not implemented.") | |
| x_dot = y_dot * dxdy |
| - :func:`jax.scipy.special.betaln` | ||
| Notes: | ||
| This function uses a Newton-Halley hybrid iterative method to find the inverse. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring states that this function uses a "Newton-Halley hybrid iterative method", but the implementation comment at line 357 and the code itself indicate that Newton's method is used for stability. Please update the docstring to accurately reflect the implementation.
| This function uses a Newton-Halley hybrid iterative method to find the inverse. | |
| This function uses Newton's method with adaptive step sizing to find the inverse. |
| def mode_guess(a, b): | ||
| """Use the mode when both a,b > 1.""" | ||
| return (a - one) / (a + b - dtype(2)) | ||
|
|
||
| def quantile_approx_guess(a, b, y): | ||
| """Approximation based on normal distribution for a,b large.""" | ||
| # When a, b are large, beta(a,b) ~ Normal(mean, variance) | ||
| mean = a / (a + b) | ||
| var = (a * b) / ((a + b) ** 2 * (a + b + one)) | ||
| # Use inverse normal | ||
| z = ndtri(y) | ||
| return mean + lax.sqrt(var) * z |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # => (d/da gammainc(a,x))|_x + (d/dx gammainc(a,x)) * dx/da = 0 | ||
| # => dx/da = -(d/da gammainc(a,x)|_x) / (d/dx gammainc(a,x)) | ||
| # | ||
| # d/da gammainc(a,x)|_x is complex, so for now we only support y_dot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment d/da gammainc(a,x)|_x is complex is misleading. For real a and x, this derivative is real-valued, although it is complicated to compute (it involves the digamma function). Please update the comment for accuracy.
| # d/da gammainc(a,x)|_x is complex, so for now we only support y_dot | |
| # d/da gammainc(a,x)|_x is complicated to compute, so for now we only support y_dot |
Summary
This PR implements the inverse CDF (percent point function) for Gamma and Beta distributions, as requested in issue #20358.
Changes
jax.scipy.special.gammaincinv: Inverse of the regularized lower incomplete gamma function using Newton-Halley hybrid method with Wilson-Hilferty transformation for initial guessesjax.scipy.special.betaincinv: Inverse of the regularized incomplete beta function using Newton's method with adaptive step sizingjax.scipy.stats.gamma.ppf: Percent point function (inverse CDF) for the Gamma distributionjax.scipy.stats.beta.ppf: Percent point function (inverse CDF) for the Beta distributionUse Cases
This enables:
scipy.stats.gamma.ppfandscipy.stats.beta.ppfTesting
All tests pass with accuracy matching SciPy within tolerances:
gammaincinv: typically 1e-6 accuracybetaincinv: typically 1e-4 accuracyFixes #20358