Skip to content

Fix bug in JAX cloning of RNG shared variables #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 24, 2023

Conversation

ricardoV94
Copy link
Member

Closes #314

@ricardoV94 ricardoV94 added bug Something isn't working jax tests labels May 23, 2023
@ricardoV94 ricardoV94 changed the title Fix jax jit rng bug Fix bug in JAX cloning of RNG shared variables May 23, 2023
@ricardoV94 ricardoV94 force-pushed the fix_jax_jit_rng_bug branch 2 times, most recently from ac8f1f2 to 02a0946 Compare May 23, 2023 08:24
@ricardoV94 ricardoV94 force-pushed the fix_jax_jit_rng_bug branch from 02a0946 to 6ce3527 Compare May 23, 2023 08:35
@maresb
Copy link
Contributor

maresb commented May 23, 2023

@ricardoV94
Copy link
Member Author

Yes this is failing because of JAX warnings when running on float32

@ricardoV94 ricardoV94 force-pushed the fix_jax_jit_rng_bug branch 2 times, most recently from bdc8a2d to 9f2e09d Compare May 23, 2023 10:10
@ricardoV94
Copy link
Member Author

The tests failing are simply due to the new numba, they should pass once #317 goes through

@ricardoV94 ricardoV94 force-pushed the fix_jax_jit_rng_bug branch from 9f2e09d to f370938 Compare May 24, 2023 09:15
@ricardoV94 ricardoV94 merged commit 53b00ea into pymc-devs:main May 24, 2023
@ricardoV94 ricardoV94 deleted the fix_jax_jit_rng_bug branch May 25, 2023 08:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working jax tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Function with only updates after givens fails in JAX mode
3 participants