Skip to content

Commit fd603c2

Browse files
authored
feat: overlay Zygote.gradient and use Enzyme instead (#1658)
* feat: overlay Zygote.gradient and use Enzyme instead * fix: add a warning message * fix: update text * fix: remove maxlog * fix: add config flag for zygote * test: check for no overlay * fix: use Enzyme.gradient directly * fix: more strong wording
1 parent 71160f6 commit fd603c2

File tree

7 files changed

+73
-0
lines changed

7 files changed

+73
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4747
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4848
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4949
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
50+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5051

5152
[sources]
5253
ReactantCore = {path = "lib/ReactantCore"}
@@ -69,6 +70,7 @@ ReactantSparseArraysExt = "SparseArrays"
6970
ReactantSpecialFunctionsExt = "SpecialFunctions"
7071
ReactantStatisticsExt = "Statistics"
7172
ReactantYaoBlocksExt = "YaoBlocks"
73+
ReactantZygoteExt = "Zygote"
7274

7375
[compat]
7476
AbstractFFTs = "1.5"
@@ -111,6 +113,7 @@ SparseArrays = "1.10"
111113
SpecialFunctions = "2.4"
112114
Statistics = "1.10"
113115
YaoBlocks = "0.13, 0.14"
116+
Zygote = "0.7"
114117
julia = "1.10"
115118
unzip_jll = "6"
116119

docs/src/api/config.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ Reactant.PrecisionConfig
3333
Reactant.DotGeneralAlgorithm
3434
```
3535

36+
### Zygote Overlay
37+
38+
- `OVERLAY_ZYGOTE_CALLS`: Whether to overlay `Zygote.gradient` calls with `Enzyme.autodiff`
39+
calls.
40+
3641
## Environment Variables
3742

3843
The following environment variables can be used to configure Reactant.

ext/ReactantZygoteExt.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module ReactantZygoteExt
2+
3+
using Reactant:
4+
Reactant, CallWithReactant, @reactant_overlay, use_overlayed_version, call_with_reactant
5+
using Zygote: Zygote
6+
using Enzyme: Enzyme, Reverse, Active, Const, Duplicated
7+
8+
# TODO: overload the following as well
9+
# - Zygote.pullback
10+
# - Zygote.jacobian
11+
# - Zygote.hessian
12+
13+
@reactant_overlay function Zygote.gradient(f::F, args...) where {F}
14+
# TODO: check `f` as well once #1642 is merged
15+
if Reactant.OVERLAY_ZYGOTE_CALLS[] && use_overlayed_version(args)
16+
@warn "Reactant doesn't support using Zygote for computing gradients. Replacing \
17+
`Zygote.gradient` with `Enzyme.autodiff` call. Please update your code to \
18+
not use `Zygote.gradient` and instead use `Enzyme.gradient` inside \
19+
`Reactant.@compile`. If this behavior is undesirable, set the \
20+
`overlay_zygote_calls` scoped value via `Reactant.with_config` to \
21+
`false`.\n\nReactant can remove this switching without any breaking change \
22+
and hence reliance on this behavior is strongly discouraged."
23+
return Enzyme.gradient(Reverse, Const(f), args...)
24+
else
25+
return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)
26+
end
27+
end
28+
29+
end

src/Configuration.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ scope will use the provided values.
3030
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
3131
- `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`,
3232
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
33+
34+
### Zygote Overlay
35+
36+
- `overlay_zygote_calls`: Whether to overlay `Zygote.gradient` calls with
37+
`Enzyme.autodiff` calls. Defaults to `true`.
3338
"""
3439
function with_config(
3540
f;
@@ -38,6 +43,7 @@ function with_config(
3843
convolution_precision=missing,
3944
lower_partialsort_to_approx_top_k=missing,
4045
fallback_approx_top_k_lowering=missing,
46+
overlay_zygote_calls=missing,
4147
)
4248
config_vars = ()
4349
dot_general_algorithm !== missing &&
@@ -58,6 +64,8 @@ function with_config(
5864
FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
5965
)
6066
)
67+
overlay_zygote_calls !== missing &&
68+
(config_vars = (config_vars..., OVERLAY_ZYGOTE_CALLS => overlay_zygote_calls))
6169

6270
return ScopedValues.with(f, config_vars...)
6371
end
@@ -379,3 +387,6 @@ function DotGeneralAlgorithm(
379387

380388
return nothing
381389
end
390+
391+
# Overlay Zygote.jl
392+
const OVERLAY_ZYGOTE_CALLS = ScopedValue(true)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3434
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3535
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3636
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
37+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3738

3839
[compat]
3940
Adapt = "4.1"
@@ -68,6 +69,7 @@ StableRNGs = "1"
6869
Statistics = "1.10"
6970
StatsBase = "0.34"
7071
Test = "1.10"
72+
Zygote = "0.7"
7173

7274
[extras]
7375
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"

test/integration/zygote.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Zygote, Reactant, Enzyme, Test
2+
3+
sumabs2(x) = sum(abs2, x)
4+
5+
@testset "Zygote" begin
6+
@testset "Zygote.gradient" begin
7+
x = Reactant.to_rarray(rand(Float32, 32, 10))
8+
9+
zyg_grad = @jit Zygote.gradient(sumabs2, x)
10+
enz_grad = @jit Enzyme.gradient(Reverse, Const(sumabs2), x)
11+
@test zyg_grad[1] isa Reactant.ConcreteRArray
12+
@test enz_grad[1] zyg_grad[1]
13+
14+
@testset "Disable Overlay" begin
15+
@test_throws Zygote.CompileError Reactant.with_config(;
16+
overlay_zygote_calls=false
17+
) do
18+
@jit Zygote.gradient(sumabs2, x)
19+
end
20+
end
21+
end
22+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5252
@safetestset "Python" include("integration/python.jl")
5353
@safetestset "Optimisers" include("integration/optimisers.jl")
5454
@safetestset "FillArrays" include("integration/fillarrays.jl")
55+
@safetestset "Zygote" include("integration/zygote.jl")
5556
end
5657

5758
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"

0 commit comments

Comments
 (0)