|
| 1 | +# [Rule configurations and calling back into AD](@id config) |
| 2 | + |
| 3 | +[`RuleConfig`](@ref) is a method for making rules conditionally defined based on the presence of certain features in the AD system. |
| 4 | +One key such feature is the ability to perform AD either in forwards or reverse mode or both. |
| 5 | + |
| 6 | +This is done with a trait-like system (not Holy Traits), where the `RuleConfig` has a union of types as its only type parameter.
| 7 | +Each type in the union represents a particular special feature of this AD.
| 8 | +To indicate that the AD system has a special property, its `RuleConfig` should be defined as: |
| 9 | +```julia |
| 10 | +struct MyADRuleConfig <: RuleConfig{Union{Feature1, Feature2}} end |
| 11 | +``` |
| 12 | +Rules that should only be defined when an AD has a particular special property are written as:
| 13 | +```julia |
| 14 | +rrule(::RuleConfig{>:Feature1}, f, args...) = # rrule that should only be defined for ADs with `Feature1`
| 15 | +
| 16 | +frule(::RuleConfig{>:Union{Feature1,Feature2}}, f, args...) = # frule that should only be defined for ADs with both `Feature1` and `Feature2`
| 17 | +``` |
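|  | +
|  | +The `>:` dispatch pattern above can be tried in isolation. The following is a minimal sketch, not ChainRulesCore code: `DemoRuleConfig`, `DemoFeature1`, `DemoFeature2`, the three config structs and `describe` are all hypothetical stand-ins defined locally so the block runs without any package.
|  | +```julia
|  | +# Local stand-ins (assumptions for illustration only; not part of ChainRulesCore).
|  | +abstract type DemoRuleConfig{T} end
|  | +struct DemoFeature1 end
|  | +struct DemoFeature2 end
|  | +
|  | +struct NoFeatureConfig <: DemoRuleConfig{Union{}} end
|  | +struct OnlyFeature1Config <: DemoRuleConfig{Union{DemoFeature1}} end
|  | +struct BothFeaturesConfig <: DemoRuleConfig{Union{DemoFeature1, DemoFeature2}} end
|  | +
|  | +# `>:X` matches any config whose union parameter is a supertype of `X`,
|  | +# i.e. whose union contains every type in `X`.
|  | +describe(::DemoRuleConfig) = "generic fallback"
|  | +describe(::DemoRuleConfig{>:DemoFeature1}) = "needs DemoFeature1"
|  | +describe(::DemoRuleConfig{>:Union{DemoFeature1, DemoFeature2}}) = "needs both features"
|  | +
|  | +results = map(describe, (NoFeatureConfig(), OnlyFeature1Config(), BothFeaturesConfig()))
|  | +```
|  | +In this sketch each config should select the most specific method whose required features it declares, so a config with no features only ever reaches the generic fallback.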
| 18 | + |
| 19 | +A prominent use of this is declaring whether or not the AD system supports being called from within rule definitions.
| 20 | + |
| 21 | +## Declaring support for calling back into ADs |
| 22 | + |
| 23 | +To declare support, or lack of support, for forwards-mode and reverse-mode AD, use the two pairs of complementary types.
| 24 | +For reverse mode: [`HasReverseMode`](@ref), [`NoReverseMode`](@ref).
| 25 | +For forwards mode: [`HasForwardsMode`](@ref), [`NoForwardsMode`](@ref).
| 26 | +AD systems that support any calling back into AD should have one type from each pair.
| 27 | + |
| 28 | +If an AD `HasReverseMode`, then it must define [`rrule_via_ad`](@ref) for that RuleConfig subtype. |
| 29 | +Similarly, if an AD `HasForwardsMode` then it must define [`frule_via_ad`](@ref) for that RuleConfig subtype. |
| 30 | + |
| 31 | +For example: |
| 32 | +```julia |
| 33 | +struct MyReverseOnlyADRuleConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end |
| 34 | + |
| 35 | +function ChainRulesCore.rrule_via_ad(::MyReverseOnlyADRuleConfig, f, args...) |
| 36 | +    ...
| 37 | +    return y, pullback
| 38 | +end |
| 39 | +``` |
| 40 | + |
| 41 | +Note that it is not actually required that the same AD is used for forward and reverse. |
| 42 | +For example [Nabla.jl](https://github.com/invenia/Nabla.jl/) is a reverse mode AD. |
| 43 | +It might declare that it `HasForwardsMode`, and then define a wrapper around [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) in order to provide that capacity. |
| 44 | + |
| 45 | +## Writing rules that call back into AD |
| 46 | + |
| 47 | +To define e.g. rules for higher order functions, it is useful to be able to call back into the AD system to get it to do some work for you. |
| 48 | + |
| 49 | +For example, the rule for reverse mode AD for `map` might like to use forward mode AD if one is available.
| 50 | +This is particularly true when only a single input collection is being mapped over.
| 51 | +In that case we know the most efficient way to compute that sub-program is in forwards mode, as each call within the map only takes a single input.
| 52 | + |
| 53 | +Note: the following is not the most efficient rule for `map` via forward, but attempts to be clearer for demonstration purposes. |
| 54 | + |
| 55 | +```julia |
| 56 | +function rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(map), f::Function, x::Array{<:Real}) |
| 57 | +    # real code would support functors/closures, but in interest of keeping example short we exclude it:
| 58 | +    @assert (fieldcount(typeof(f)) == 0) "Functors/Closures are not supported"
| 59 | +
| 60 | +    y_and_ẏ = map(x) do xi
| 61 | +        frule_via_ad(config, (NoTangent(), one(xi)), f, xi)
| 62 | +    end
| 63 | +    y = first.(y_and_ẏ)
| 64 | +    ẏ = last.(y_and_ẏ)
| 65 | +
| 66 | +    pullback_map(ȳ) = NoTangent(), NoTangent(), ȳ .* ẏ
| 67 | +    return y, pullback_map
| 68 | +end |
| 69 | +``` |
| 70 | + |
| 71 | +## Writing rules that depend on other special requirements of the AD
| 72 | + |
| 73 | +`>:HasReverseMode` and `>:HasForwardsMode` are two examples of special properties that a rule can require of a `RuleConfig`.
| 74 | +Others could also exist, but right now these are the only two.
| 75 | +It is likely that more will be provided in the future, e.g. for mutation support.
| 76 | +
| 77 | +Such a property might look like:
| 78 | +```julia |
| 79 | +struct SupportsMutation end |
| 80 | + |
| 81 | +function rrule(
| 82 | +    ::RuleConfig{>:SupportsMutation}, ::typeof(push!), x::Vector, a
| 83 | +)
| 84 | +    y = push!(x, a)
| 85 | +
| 86 | +    function push!_pullback(ȳ)
| 87 | +        pop!(x)  # undo change to primal in case it is used in another pullback we haven't called yet
| 88 | +        ā = pop!(ȳ)  # last entry of ȳ is the tangent of `a`; gradient for `x` is accumulated by mutating ȳ, so return ZeroTangent for it
| 89 | +        return NoTangent(), ZeroTangent(), ā
| 90 | +    end
| 91 | +
| 92 | +    return y, push!_pullback
| 93 | +end
| 94 | +``` |
| 95 | +and it would be used in the AD e.g. as follows: |
| 96 | +```julia |
| 97 | +struct EnzymeRuleConfig <: RuleConfig{Union{SupportsMutation, HasReverseMode, NoForwardsMode}} end
| 98 | +``` |
| 99 | + |
| 100 | +Note: you can only depend on the presence of a feature, not its absence.
| 101 | +This means we may need to define features and their complements when one is not the obvious default (as in the case of [`HasReverseMode`](@ref)/[`NoReverseMode`](@ref) and [`HasForwardsMode`](@ref)/[`NoForwardsMode`](@ref)). |
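|  | +
|  | +To see why a complement type is needed, consider the following minimal sketch. It uses hypothetical local stand-ins (`SketchConfig`, `SketchHasForwards`, `SketchNoForwards`, the two config structs and `strategy`), not the real ChainRulesCore types, so that it runs on its own.
|  | +```julia
|  | +# Local stand-ins (assumptions for illustration only; not part of ChainRulesCore).
|  | +abstract type SketchConfig{T} end
|  | +struct SketchHasForwards end
|  | +struct SketchNoForwards end
|  | +
|  | +struct ForwardCapableConfig <: SketchConfig{Union{SketchHasForwards}} end
|  | +struct ReverseOnlyConfig <: SketchConfig{Union{SketchNoForwards}} end
|  | +
|  | +# There is no way to write "does not contain SketchHasForwards" as a dispatch constraint,
|  | +# so the fallback strategy dispatches on the presence of the complement type instead.
|  | +strategy(::SketchConfig{>:SketchHasForwards}) = :call_back_into_forward_mode
|  | +strategy(::SketchConfig{>:SketchNoForwards}) = :hand_written_fallback
|  | +
|  | +chosen = (strategy(ForwardCapableConfig()), strategy(ReverseOnlyConfig()))
|  | +```
|  | +A config that declares neither type would match neither method in this sketch, which is why an AD is expected to declare one type from each complementary pair.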
| 102 | + |
| 103 | + |
| 104 | +Such special properties generally should only be defined in `ChainRulesCore`.
| 105 | +(Theoretically, they could be defined elsewhere, but the AD and the package containing the rule need to load them, and ChainRulesCore is the place for things like that.) |
| 106 | + |
| 107 | + |
| 108 | +## Writing rules that are only for your own AD |
| 109 | + |
| 110 | +A special case of the above is writing rules that are defined only for your own AD.
| 111 | +These are rules that would otherwise be type piracy and would affect other AD systems.
| 112 | +This could be done by making up a special property type and dispatching on it.
| 113 | +But there is no need, as we can dispatch on the `RuleConfig` subtype directly. |
| 114 | + |
| 115 | +For example in order to avoid mutation in nested AD situations, Zygote might want to have a rule for [`add!!`](@ref) that makes it just do `+`. |
| 116 | + |
| 117 | +```julia |
| 118 | +struct ZygoteConfig <: RuleConfig{Union{}} end |
| 119 | + |
| 120 | +rrule(::ZygoteConfig, ::typeof(ChainRulesCore.add!!), a, b) = a+b, Δ->(NoTangent(), Δ, Δ)
| 121 | +``` |
| 122 | + |
| 123 | +An alternative to writing rules for only one AD would be to add new special property definitions to ChainRulesCore (as described above) that capture what makes that AD special.