Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ include("projection.jl")

include("config.jl")
include("rules.jl")
include("chunked_rules.jl")
include("rule_definition_tools.jl")
include("ignore_derivatives.jl")

Expand Down
17 changes: 17 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,20 @@ See also: [`frule_via_ad`](@ref), [`RuleConfig`](@ref) and the documentation on
[rule configurations and calling back into AD](@ref config)
"""
function rrule_via_ad end


abstract type ChunkedRuleCapability end
"""
HasChunkedMode

This trait indicates that a `RuleConfig{>:HasChunkedMode}` can perform chunked AD.
"""
struct HasChunkedMode <: ChunkedRuleCapability end

"""
NoChunkedMode

This is the complement to [`HasChunkedMode`](@ref). To avoid ambiguities [`RuleConfig`]s
that do not support chunked AD should be `RuleConfig{>:NoChunkedMode}`.
"""
struct NoChunkedMode <: ChunkedRuleCapability end
25 changes: 24 additions & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,29 @@ function (::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args.
return frule_kwfunc(kws, frule, args...)
end

struct ProductTangent{P}
partials::P
end
Comment on lines +81 to +83
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How much is gained from this kind of chunking? I mean storing several partials in some tuple. For reverse mode, nothing, I think. For forward mode, maybe derivatives_using_output has the same benefits.

The big gain seems to be from making a solid matrix to represent many vectors, and e.g. getting matrix multiplication. So I think the initial design ought to make that work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This design does this (I think). partials here can either be a Tuple{Tuple} or an eachrow(Matrix)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I guess a vector of vectors can have the same meaning as an EachCol. Just a different path, by dispatch on ProductTangent{<:EachSlice}.


function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx...), f, args...)
frule((Δf, Δx...), f, args...)
end

function frule(::RuleConfig{>:HasChunkedMode}, (Δf, Δx::ProductTangent), f, args...)
frule((Δf, first(Δx)), args...)[1], ProductTangent(map(Δrow->frule((Δf, Δrow)[2], f, args...), Δx))
end

function rrule(::RuleConfig{>:HasChunkedMode}, f, args...)
y, back = rrule(args...)
return y, ApplyBack(back)
end

struct ApplyBack{F}; back::F; end

(a::ApplyBack)(dy) = a.back(dy)
(a::ApplyBack)(dy::ProductTangent) = ProductTangent(map(a.back, dy.partials)) # or some Tangent recursion?


"""
rrule([::RuleConfig,] f, x...)

Expand Down Expand Up @@ -149,7 +172,7 @@ const NO_RRULE_DOC = """
This is an piece of infastructure supporting opting out of [`rrule`](@ref).
It follows the signature for `rrule` exactly.
A collection of type-tuples is stored in its method-table.
If something has this defined, it means that it must having a must also have a `rrule`,
If something has this defined, it means that it must having a must also have a `rrule`,
defined that returns `nothing`.

!!! warning "Do not overload no_rrule directly"
Expand Down
5 changes: 3 additions & 2 deletions src/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation.
macro thunk(body)
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
# so we get useful stack traces if it errors.
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
return :(Thunk($(esc(func))))
#func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
#return :(Thunk($(esc(func))))
return esc(body)
end

"""
Expand Down