From 7c83518ab3014ebfb61130f9ad452c5ab3eb03bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 15:27:12 -0400 Subject: [PATCH 1/8] feat: overlay Zygote.gradient and use Enzyme instead --- Project.toml | 3 +++ ext/ReactantZygoteExt.jl | 25 +++++++++++++++++++++++++ test/Project.toml | 2 ++ test/integration/zygote.jl | 14 ++++++++++++++ test/runtests.jl | 1 + 5 files changed, 45 insertions(+) create mode 100644 ext/ReactantZygoteExt.jl create mode 100644 test/integration/zygote.jl diff --git a/Project.toml b/Project.toml index 737ea28018..87bac86d19 100644 --- a/Project.toml +++ b/Project.toml @@ -46,6 +46,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [sources] ReactantCore = {path = "lib/ReactantCore"} @@ -67,6 +68,7 @@ ReactantRandom123Ext = "Random123" ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" +ReactantZygoteExt = "Zygote" [compat] AbstractFFTs = "1.5" @@ -107,6 +109,7 @@ Sockets = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" YaoBlocks = "0.13, 0.14" +Zygote = "0.7" julia = "1.10" unzip_jll = "6" diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl new file mode 100644 index 0000000000..0604ee5c1f --- /dev/null +++ b/ext/ReactantZygoteExt.jl @@ -0,0 +1,25 @@ +module ReactantZygoteExt + +using Reactant: + Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant +using Zygote: Zygote +using Enzyme: Enzyme, Reverse, Active, Const, Duplicated + +# TODO: overload the following as well +# - Zygote.pullback +# - Zygote.jacobian +# - Zygote.hessian + +@reactant_overlay function Zygote.gradient(f::F, args...) where {F} + # TODO: check `f` as well once #1642 is merged + if use_overlayed_version(args) + dargs = map(Enzyme.make_zero, args) + duplicated = map(Duplicated, args, dargs) + Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) + return dargs + else + return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...) + end +end + +end diff --git a/test/Project.toml b/test/Project.toml index bb9d5f4cfc..a75f670c27 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,6 +31,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Adapt = "4.1" @@ -65,6 +66,7 @@ StableRNGs = "1" Statistics = "1.10" StatsBase = "0.34" Test = "1.10" +Zygote = "0.7" [extras] Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" diff --git a/test/integration/zygote.jl b/test/integration/zygote.jl new file mode 100644 index 0000000000..0a0c4b5ac0 --- /dev/null +++ b/test/integration/zygote.jl @@ -0,0 +1,14 @@ +using Zygote, Reactant, Enzyme, Test + +sumabs2(x) = sum(abs2, x) + +@testset "Zygote" begin + @testset "Zygote.gradient" begin + x = Reactant.to_rarray(rand(Float32, 32, 10)) + + zyg_grad = @jit Zygote.gradient(sumabs2, x) + enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x) + @test zyg_grad[1] isa Reactant.ConcreteRArray + @test enz_grad[1] ≈ zyg_grad[1] + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d52ebebe90..ec80b802c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + @safetestset "Zygote" include("integration/zygote.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From bc56a8a48f7c946c48c351e98b58929a573e33a4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 15:36:58 -0400 Subject: [PATCH 2/8] fix: add a warning message --- ext/ReactantZygoteExt.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index 0604ee5c1f..30ad3404a8 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -13,6 +13,9 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated @reactant_overlay function Zygote.gradient(f::F, args...) where {F} # TODO: check `f` as well once #1642 is merged if use_overlayed_version(args) + @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ + `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ + not use `Zygote.gradient` inside `Reactant.@compile`." maxlog = 1 dargs = map(Enzyme.make_zero, args) duplicated = map(Duplicated, args, dargs) Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) From 46e1792c85cae6ea975d49dd5eb1793f40704310 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 15:52:57 -0400 Subject: [PATCH 3/8] fix: update text --- ext/ReactantZygoteExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index 30ad3404a8..2a86bc7c14 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -15,7 +15,8 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated if use_overlayed_version(args) @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ - not use `Zygote.gradient` inside `Reactant.@compile`." maxlog = 1 + not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ + `Reactant.@compile`." maxlog = 1 dargs = map(Enzyme.make_zero, args) duplicated = map(Duplicated, args, dargs) Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) From 2a27445ec635ddd60ef66f3d1bc86fe9fdaeac99 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 15:53:16 -0400 Subject: [PATCH 4/8] fix: remove maxlog --- ext/ReactantZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index 2a86bc7c14..5589609a12 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -16,7 +16,7 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ - `Reactant.@compile`." maxlog = 1 + `Reactant.@compile`." dargs = map(Enzyme.make_zero, args) duplicated = map(Duplicated, args, dargs) Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) From 6dca23f47e2af42f0ee68d3786c4d43ab4a3183a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 15:55:11 -0400 Subject: [PATCH 5/8] fix: add config flag for zygote --- docs/src/api/config.md | 5 +++++ ext/ReactantZygoteExt.jl | 2 +- src/Configuration.jl | 12 ++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/src/api/config.md b/docs/src/api/config.md index a3915c078c..79024b99f7 100644 --- a/docs/src/api/config.md +++ b/docs/src/api/config.md @@ -33,6 +33,11 @@ Reactant.PrecisionConfig Reactant.DotGeneralAlgorithm ``` +### Zygote Overlay + +- `OVERLAY_ZYGOTE_CALLS`: Whether to overlay `Zygote.gradient` calls with `Enzyme.autodiff` + calls. + ## Environment Variables The following environment variables can be used to configure Reactant. diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index 5589609a12..e0b4e16e47 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -12,7 +12,7 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated @reactant_overlay function Zygote.gradient(f::F, args...) where {F} # TODO: check `f` as well once #1642 is merged - if use_overlayed_version(args) + if use_overlayed_version(args) && Reactant.OVERLAY_ZYGOTE_CALLS[] @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ diff --git a/src/Configuration.jl b/src/Configuration.jl index 5b0eaa00af..73aaeb877b 100644 --- a/src/Configuration.jl +++ b/src/Configuration.jl @@ -30,6 +30,11 @@ scope will use the provided values. or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. - `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`, or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`. + +### Zygote Overlay + + - `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with + `Enzyme.autodiff` calls. Defaults to `true`. """ function with_config( f; @@ -38,6 +43,7 @@ function with_config( convolution_precision=missing, lower_partialsort_to_approx_top_k=missing, fallback_approx_top_k_lowering=missing, + overlay_zygote_calls=missing, ) config_vars = () dot_general_algorithm !== missing && @@ -58,6 +64,9 @@ function with_config( FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering, ) ) + overlay_zygote_calls !== missing && ( + config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls) + ) return ScopedValues.with(f, config_vars...) end @@ -379,3 +388,6 @@ function DotGeneralAlgorithm( return nothing end + +# Overlay Zygote.jl +const OVERLAY_ZYGOTE_CALLS = ScopedValue(true) From be5f2a778f402fc0dffb58e23356af2f1a2c0961 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 16:02:27 -0400 Subject: [PATCH 6/8] test: check for no overlay --- ext/ReactantZygoteExt.jl | 5 +++-- src/Configuration.jl | 5 ++--- test/integration/zygote.jl | 8 ++++++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index e0b4e16e47..10cc5a3c66 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -12,11 +12,12 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated @reactant_overlay function Zygote.gradient(f::F, args...) where {F} # TODO: check `f` as well once #1642 is merged - if use_overlayed_version(args) && Reactant.OVERLAY_ZYGOTE_CALLS[] + if Reactant.OVERLAY_ZYGOTE_CALLS[] && use_overlayed_version(args) @warn "Reactant doesn't support using Zygote for computing gradients. Replacing \ `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ - `Reactant.@compile`." + `Reactant.@compile`. If this behavior is undesirable, set the \ + `overlay_zygote_calls` scoped value via `Reactant.with_config` to `false`." dargs = map(Enzyme.make_zero, args) duplicated = map(Duplicated, args, dargs) Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) diff --git a/src/Configuration.jl b/src/Configuration.jl index 73aaeb877b..35149345e9 100644 --- a/src/Configuration.jl +++ b/src/Configuration.jl @@ -64,9 +64,8 @@ function with_config( FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering, ) ) - overlay_zygote_calls !== missing && ( - config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls) - ) + overlay_zygote_calls !== missing && + (config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls)) return ScopedValues.with(f, config_vars...) end diff --git a/test/integration/zygote.jl b/test/integration/zygote.jl index 0a0c4b5ac0..f442cef391 100644 --- a/test/integration/zygote.jl +++ b/test/integration/zygote.jl @@ -10,5 +10,13 @@ sumabs2(x) = sum(abs2, x) enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x) @test zyg_grad[1] isa Reactant.ConcreteRArray @test enz_grad[1] ≈ zyg_grad[1] + + @testset "Disable Overlay" begin + @test_throws Zygote.CompileError Reactant.with_config(; + overlay_zygote_calls=false + ) do + @jit Zygote.gradient(sumabs2, x) + end + end end end From 66026574ef682fefc6f50face884c7cff78cd505 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 16:08:58 -0400 Subject: [PATCH 7/8] fix: use Enzyme.gradient directly --- ext/ReactantZygoteExt.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index 10cc5a3c66..d649d23b7a 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -18,10 +18,7 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ `Reactant.@compile`. If this behavior is undesirable, set the \ `overlay_zygote_calls` scoped value via `Reactant.with_config` to `false`." - dargs = map(Enzyme.make_zero, args) - duplicated = map(Duplicated, args, dargs) - Reactant.overload_autodiff(Reverse, Const(f), Active, duplicated...) - return dargs + return Enzyme.gradient(Reverse, Const(f), args...) else return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...) end From f8785aa268d4b9533def6b64c559e185c655e1b1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Sep 2025 17:29:02 -0400 Subject: [PATCH 8/8] fix: more strong wording --- ext/ReactantZygoteExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index d649d23b7a..f50229ea17 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -17,7 +17,9 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated `Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \ not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \ `Reactant.@compile`. If this behavior is undesirable, set the \ - `overlay_zygote_calls` scoped value via `Reactant.with_config` to `false`." + `overlay_zygote_calls` scoped value via `Reactant.with_config` to \ + `false`.\n\nReactant can remove this switching without any breaking change \ + and hence reliance on this behavior is strongly discouraged." return Enzyme.gradient(Reverse, Const(f), args...) else return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)