diff --git a/pyproject.toml b/pyproject.toml index 9e09539..06fb168 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,13 +155,13 @@ dask-core = ">=2025.7.0" # No distributed, tornado, etc. sparse = ">=0.17.0" [tool.pixi.feature.backends.target.linux-64.dependencies] -jax = ">=0.6.0,!=0.6.2" # 0.6.2 segfaults on Linux CUDA +jax = ">=0.7.0,!=0.6.2" # 0.6.2 segfaults on Linux CUDA [tool.pixi.feature.backends.target.osx-64.dependencies] -jax = ">=0.6.0,!=0.6.2" +jax = ">=0.7.0,!=0.6.2" [tool.pixi.feature.backends.target.osx-arm64.dependencies] -jax = ">=0.6.0,!=0.6.2" +jax = ">=0.7.0,!=0.6.2" [tool.pixi.feature.backends.target.win-64.dependencies] # jax = "*" # unavailable @@ -177,7 +177,7 @@ system-requirements = { cuda = "12" } [tool.pixi.feature.cuda-backends.target.linux-64.dependencies] cupy = ">=13.5.1" -jaxlib = { version = ">=0.6.0", build = "cuda12*" } +jaxlib = { version = ">=0.7.0", build = "cuda12*" } pytorch = { version = ">=2.7.1", build = "cuda12*" } [tool.pixi.feature.cuda-backends.target.osx-64.dependencies]