Commit 7947bed

Merge pull request #363 from JuliaDiff/ox/config
RuleConfigs (include for calling back into AD)
2 parents 7d667c5 + 87fe59e commit 7947bed

10 files changed: +419 -12 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.10.3"
3+
version = "0.10.4"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/Manifest.toml

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.44"
16+
version = "0.10.1"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -34,10 +34,10 @@ deps = ["Random", "Serialization", "Sockets"]
3434
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
3535

3636
[[DocStringExtensions]]
37-
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
38-
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
37+
deps = ["LibGit2"]
38+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
3939
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
40-
version = "0.8.4"
40+
version = "0.8.5"
4141

4242
[[DocThemeIndigo]]
4343
deps = ["Sass"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ makedocs(
4747
pages=[
4848
"Introduction" => "index.md",
4949
"FAQ" => "FAQ.md",
50+
"Rule configurations and calling back into AD" => "config.md",
5051
"Writing Good Rules" => "writing_good_rules.md",
5152
"Complex Numbers" => "complex.md",
5253
"Deriving Array Rules" => "arrays.md",

docs/src/api.md

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,13 @@ add!!
3434
ChainRulesCore.is_inplaceable_destination
3535
```
3636

37+
## RuleConfig
38+
```@autodocs
39+
Modules = [ChainRulesCore]
40+
Pages = ["config.jl"]
41+
Private = false
42+
```
43+
3744
## Internal
3845
```@docs
3946
ChainRulesCore.AbstractTangent

docs/src/config.md

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
1+
# [Rule configurations and calling back into AD](@id config)
2+
3+
[`RuleConfig`](@ref) is a method for making rules conditionally defined based on the presence of certain features in the AD system.
4+
One key such feature is the ability to perform AD either in forwards or reverse mode or both.
5+
6+
This is done with a trait-like system (not Holy Traits), where the `RuleConfig` has a union of types as its only type-parameter.
7+
Each type in the union represents a particular special feature of that AD system.
8+
To indicate that the AD system has a special property, its `RuleConfig` should be defined as:
9+
```julia
10+
struct MyADRuleConfig <: RuleConfig{Union{Feature1, Feature2}} end
11+
```
12+
And rules that should only be defined when an AD has a particular special property write:
13+
```julia
14+
rrule(::RuleConfig{>:Feature1}, f, args...) = # rrule that should only be defined for ADs with `Feature1`
15+
16+
frule(::RuleConfig{>:Union{Feature1,Feature2}}, f, args...) = # frule that should only be defined for ADs with both `Feature1` and `Feature2`
17+
```
18+
19+
A prominent use of this is in declaring whether or not the AD system supports being called from within rule definitions.
20+
21+
## Declaring support for calling back into ADs
22+
23+
To declare support or lack of support for forward and reverse-mode, use the two pairs of complementary types.
24+
For reverse mode: [`HasReverseMode`](@ref), [`NoReverseMode`](@ref).
25+
For forwards mode: [`HasForwardsMode`](@ref), [`NoForwardsMode`](@ref).
26+
AD systems that support any calling back into AD should have one from each set.
27+
28+
If an AD `HasReverseMode`, then it must define [`rrule_via_ad`](@ref) for that RuleConfig subtype.
29+
Similarly, if an AD `HasForwardsMode` then it must define [`frule_via_ad`](@ref) for that RuleConfig subtype.
30+
31+
For example:
32+
```julia
33+
struct MyReverseOnlyADRuleConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end
34+
35+
function ChainRulesCore.rrule_via_ad(::MyReverseOnlyADRuleConfig, f, args...)
36+
...
37+
return y, pullback
38+
end
39+
```
40+
41+
Note that it is not actually required that the same AD is used for forward and reverse.
42+
For example [Nabla.jl](https://github.com/invenia/Nabla.jl/) is a reverse mode AD.
43+
It might declare that it `HasForwardsMode`, and then define a wrapper around [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) in order to provide that capacity.
44+
45+
## Writing rules that call back into AD
46+
47+
To define e.g. rules for higher order functions, it is useful to be able to call back into the AD system to get it to do some work for you.
48+
49+
For example the rule for reverse mode AD for `map` might like to use forward mode AD if one is available.
50+
This is particularly so when only a single input collection is being mapped over.
51+
In that case we know the most efficient way to compute that sub-program is in forward mode, as each call within the map only takes a single input.
52+
53+
Note: the following is not the most efficient rule for `map` via forward, but attempts to be clearer for demonstration purposes.
54+
55+
```julia
56+
function rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(map), f::Function, x::Array{<:Real})
57+
# real code would support functors/closures, but in interest of keeping example short we exclude it:
58+
@assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported"
59+
60+
y_and_ẏ = map(x) do xi
61+
frule_via_ad(config, (NoTangent(), one(xi)), f, xi)
62+
end
63+
y = first.(y_and_ẏ)
64+
ẏ = last.(y_and_ẏ)
65+
66+
pullback_map(ȳ) = NoTangent(), NoTangent(), ȳ .* ẏ
67+
return y, pullback_map
68+
end
69+
```
70+
71+
## Writing rules that depend on other special requirements of the AD
72+
73+
The `>:HasReverseMode` and `>:HasForwardsMode` are two examples of special properties that a `RuleConfig` could allow.
74+
Others could also exist, but right now they are the only two.
75+
It is likely that in the future such types will be provided for e.g. mutation support.
76+
77+
Such a thing would look like:
78+
```julia
79+
struct SupportsMutation end
80+
81+
function rrule(
82+
::RuleConfig{>:SupportsMutation}, ::typeof(push!), x::Vector
83+
)
84+
y = push!(x)
85+
86+
function push!_pullback(ȳ)
87+
pop!(x) # undo change to primal in case it is used in another pullback we haven't called yet
88+
pop!(ȳ) # accumulate gradient via mutating ȳ, then return ZeroTangent
89+
return NoTangent(), ZeroTangent()
90+
end
91+
92+
return y, push!_pullback
93+
end
94+
```
95+
and it would be used in the AD e.g. as follows:
96+
```julia
97+
struct EnzymeRuleConfig <: RuleConfig{Union{SupportsMutation, HasReverseMode, NoForwardsMode}} end
98+
```
99+
100+
Note: you can only depend on the presence of a feature, not its absence.
101+
This means we may need to define features and their complements when one is not the obvious default (as in the case of [`HasReverseMode`](@ref)/[`NoReverseMode`](@ref) and [`HasForwardsMode`](@ref)/[`NoForwardsMode`](@ref)).
102+
103+
104+
Such special properties should generally only be defined in `ChainRulesCore`.
105+
(Theoretically, they could be defined elsewhere, but the AD and the package containing the rule need to load them, and ChainRulesCore is the place for things like that.)
106+
107+
108+
## Writing rules that are only for your own AD
109+
110+
A special case of the above is writing rules that are defined only for your own AD.
111+
These are rules that would otherwise be type piracy and would affect other AD systems.
112+
This could be done via making up a special property type and dispatching on it.
113+
But there is no need, as we can dispatch on the `RuleConfig` subtype directly.
114+
115+
For example in order to avoid mutation in nested AD situations, Zygote might want to have a rule for [`add!!`](@ref) that makes it just do `+`.
116+
117+
```julia
118+
struct ZygoteConfig <: RuleConfig{Union{}} end
119+
120+
rrule(::ZygoteConfig, ::typeof(ChainRulesCore.add!!), a, b) = a+b, Δ->(NoTangent(), Δ, Δ)
121+
```
122+
123+
An alternative to rules only for one AD would be to add new special property definitions to ChainRulesCore (as described above) which would capture what makes that AD special.
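
Note on the `docs/src/config.md` diff above: the `>:` dispatch it describes can be tried without loading the package. The following is a minimal sketch, assuming local stand-ins for `RuleConfig` and the capability types (they mirror the definitions added in `src/config.jl`, but are redefined here so the snippet runs on its own). `ReverseOnlyConfig`, `BothModesConfig`, `PlainConfig` and `describe` are hypothetical names used only for illustration and are not part of this commit.

```julia
# Local stand-ins mirroring src/config.jl, so this runs without ChainRulesCore.
abstract type RuleConfig{T} end
struct HasReverseMode end
struct NoReverseMode end
struct HasForwardsMode end
struct NoForwardsMode end

# Hypothetical AD configs: each lists its features as a Union type parameter.
struct ReverseOnlyConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end
struct BothModesConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
struct PlainConfig <: RuleConfig{Union{}} end

# `describe` is a hypothetical stand-in for a rule.
# The untyped method plays the role of the config-less fallback.
describe(::RuleConfig) = "fallback"
# Selected when the config's Union contains HasForwardsMode.
describe(::RuleConfig{>:HasForwardsMode}) = "uses forward mode"
# Selected when the config's Union contains both listed features.
describe(::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}) = "reverse only"

println(describe(BothModesConfig()))
println(describe(ReverseOnlyConfig()))
println(describe(PlainConfig()))
```

A config whose union contained both `HasForwardsMode` and `NoForwardsMode` would match both specialised methods, so dispatch could likely not pick between them. That seems to be one reason the documentation asks for exactly one type from each complementary pair.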

src/ChainRulesCore.jl

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,11 @@ using SparseArrays: SparseVector, SparseMatrixCSC
55
using Compat: hasfield
66

77
export frule, rrule # core function
8-
export @non_differentiable, @scalar_rule, @thunk, @not_implemented # definition helper macros
8+
# rule configurations
9+
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
10+
export frule_via_ad, rrule_via_ad
11+
# definition helper macros
12+
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
913
export canonicalize, extern, unthunk # differential operations
1014
export add!! # gradient accumulation operations
1115
# differentials
@@ -23,6 +27,7 @@ include("differentials/notimplemented.jl")
2327
include("differential_arithmetic.jl")
2428
include("accumulation.jl")
2529

30+
include("config.jl")
2631
include("rules.jl")
2732
include("rule_definition_tools.jl")
2833

src/config.jl

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
1+
"""
2+
RuleConfig{T}
3+
4+
The configuration for what rules to use.
5+
`T`: **traits**. This should be a `Union` of all special traits needed for rules to be
6+
allowed to be defined for your AD. If nothing special is needed, this should be set to `Union{}`.
7+
8+
**AD authors** should define a subtype of `RuleConfig` to use when calling `frule`/`rrule`.
9+
10+
**Rule authors** can dispatch on this config when defining rules.
11+
For example:
12+
```julia
13+
# only define rrule for `pop!` on AD systems where mutation is supported.
14+
rrule(::RuleConfig{>:SupportsMutation}, ::typeof(pop!), ::Vector) = ...
15+
16+
# this definition of map is for any AD that defines a forwards mode
17+
rrule(conf::RuleConfig{>:HasForwardsMode}, ::typeof(map), ::Vector) = ...
18+
19+
# this definition of map is for any AD that only defines a reverse mode.
20+
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well.
21+
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, ::typeof(map), ::Vector) = ...
22+
```
23+
24+
For more details see [rule configurations and calling back into AD](@ref config).
25+
"""
26+
abstract type RuleConfig{T} end
27+
28+
# Broadcast like a scalar
29+
Base.Broadcast.broadcastable(config::RuleConfig) = Ref(config)
30+
31+
abstract type ReverseModeCapability end
32+
33+
"""
34+
HasReverseMode
35+
36+
This trait indicates that a `RuleConfig{>:HasReverseMode}` can perform reverse mode AD.
37+
If it is set then [`rrule_via_ad`](@ref) must be implemented.
38+
"""
39+
struct HasReverseMode <: ReverseModeCapability end
40+
41+
"""
42+
NoReverseMode
43+
44+
This is the complement to [`HasReverseMode`](@ref). To avoid ambiguities [`RuleConfig`](@ref)s
45+
that do not support performing reverse mode AD should be `RuleConfig{>:NoReverseMode}`.
46+
"""
47+
struct NoReverseMode <: ReverseModeCapability end
48+
49+
abstract type ForwardsModeCapability end
50+
51+
"""
52+
HasForwardsMode
53+
54+
This trait indicates that a `RuleConfig{>:HasForwardsMode}` can perform forward mode AD.
55+
If it is set then [`frule_via_ad`](@ref) must be implemented.
56+
"""
57+
struct HasForwardsMode <: ForwardsModeCapability end
58+
59+
"""
60+
NoForwardsMode
61+
62+
This is the complement to [`HasForwardsMode`](@ref). To avoid ambiguities [`RuleConfig`](@ref)s
63+
that do not support performing forwards mode AD should be `RuleConfig{>:NoForwardsMode}`.
64+
"""
65+
struct NoForwardsMode <: ForwardsModeCapability end
66+
67+
68+
"""
69+
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...)
70+
71+
This function has the same API as [`frule`](@ref), but operates via performing forwards mode
72+
automatic differentiation.
73+
Any `RuleConfig` subtype that supports the [`HasForwardsMode`](@ref) special feature must
74+
provide an implementation of it.
75+
76+
See also: [`rrule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
77+
[rule configurations and calling back into AD](@ref config)
78+
"""
79+
function frule_via_ad end
80+
81+
"""
82+
rrule_via_ad(::RuleConfig{>:HasReverseMode}, f, args...; kwargs...)
83+
84+
This function has the same API as [`rrule`](@ref), but operates via performing reverse mode
85+
automatic differentiation.
86+
Any `RuleConfig` subtype that supports the [`HasReverseMode`](@ref) special feature must
87+
provide an implementation of it.
88+
89+
See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
90+
[rule configurations and calling back into AD](@ref config)
91+
"""
92+
function rrule_via_ad end
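
Note on the `src/config.jl` diff above: the `Base.Broadcast.broadcastable` method makes a config behave like a scalar under broadcasting, so it can be passed alongside a collection in a dotted call. A minimal sketch, assuming a local stand-in for `RuleConfig` so it runs without the package; `PlainConfig` and `scaled` are hypothetical names for illustration only.

```julia
# Local stand-in mirroring src/config.jl.
abstract type RuleConfig{T} end

# Same definition as in the diff: wrap the config in a Ref so that
# broadcasting treats it as a single value rather than trying to iterate it.
Base.Broadcast.broadcastable(config::RuleConfig) = Ref(config)

# Hypothetical config and function for illustration.
struct PlainConfig <: RuleConfig{Union{}} end
scaled(::RuleConfig, x) = 2x

# The config is reused for every element of the vector.
result = scaled.(PlainConfig(), [1, 2, 3])
println(result)
```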

src/rules.jl

Lines changed: 21 additions & 6 deletions
@@ -1,5 +1,5 @@
11
"""
2-
frule((Δf, Δx...), f, x...)
2+
frule([::RuleConfig,] (Δf, Δx...), f, x...)
33
44
Expressing the output of `f(x...)` as `Ω`, return the tuple:
55
@@ -50,15 +50,21 @@ So this is actually a [`Tangent`](@ref):
5050
```jldoctest frule
5151
julia> Δsincosx
5252
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
53-
```.
53+
```
5454
55+
The optional [`RuleConfig`](@ref) argument allows specifying frules only for AD systems that
56+
support given features. If not needed, then it can be omitted and the `frule` without it
57+
will be hit as a fallback. This is the case for most rules.
5558
56-
See also: [`rrule`](@ref), [`@scalar_rule`](@ref)
59+
See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
5760
"""
58-
frule(::Any, ::Vararg{Any}; kwargs...) = nothing
61+
frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing
62+
63+
# if no config is present then fallback to config-less rules
64+
frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...)
5965

6066
"""
61-
rrule(f, x...)
67+
rrule([::RuleConfig,] f, x...)
6268
6369
Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
6470
as `Ω`, return the tuple:
@@ -101,10 +107,19 @@ julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
101107
true
102108
```
103109
104-
See also: [`frule`](@ref), [`@scalar_rule`](@ref)
110+
The optional [`RuleConfig`](@ref) argument allows specifying rrules only for AD systems that
111+
support given features. If not needed, then it can be omitted and the `rrule` without it
112+
will be hit as a fallback. This is the case for most rules.
113+
114+
See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
105115
"""
106116
rrule(::Any, ::Vararg{Any}) = nothing
107117

118+
# if no config is present then fallback to config-less rules
119+
rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...)
120+
# TODO do we need to do something for kwargs special here for performance?
121+
# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368
122+
108123
# Manual fallback for keyword arguments. Usually this would be generated by
109124
#
110125
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
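
Note on the `src/rules.jl` diff above: the new fallback methods mean a rule written without a config is still found when an AD passes one. The sketch below illustrates that dispatch chain under stated assumptions: `my_rrule` is a hypothetical stand-in for `rrule` (keyword arguments omitted), the types are local stand-ins rather than the real ChainRulesCore ones, and `nothing` is used in place of `NoTangent()` to keep the snippet dependency-free.

```julia
# Local stand-ins; not the real ChainRulesCore definitions.
abstract type RuleConfig{T} end
struct HasForwardsMode end
struct PlainConfig <: RuleConfig{Union{}} end
struct ForwardCapableConfig <: RuleConfig{HasForwardsMode} end

# Catch-all: no rule is defined for this call.
my_rrule(::Any, ::Vararg{Any}) = nothing
# If a config is passed but no config-specific rule exists, drop the config.
my_rrule(::RuleConfig, f, args...) = my_rrule(f, args...)

# An ordinary config-less rule (most rules look like this).
my_rrule(::typeof(sin), x::Real) = (sin(x), ȳ -> (nothing, ȳ * cos(x)))

# A config-specific rule, chosen only for ADs that declare HasForwardsMode.
my_rrule(::RuleConfig{>:HasForwardsMode}, ::typeof(cos), x::Real) =
    (cos(x), ȳ -> (nothing, -ȳ * sin(x)))

y, pullback = my_rrule(PlainConfig(), sin, 0.0)        # falls back to the config-less rule
no_rule = my_rrule(PlainConfig(), cos, 0.0)            # no applicable rule, so `nothing`
with_rule = my_rrule(ForwardCapableConfig(), cos, 0.0) # hits the config-specific rule
```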
