Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.1"
version = "1.11.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
55 changes: 53 additions & 2 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))

# Tuple
# Tuple and NamedTuple
function ProjectTo(x::Tuple)
elements = map(ProjectTo, x)
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
Expand All @@ -296,10 +296,22 @@ function ProjectTo(x::Tuple)
return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
end
end
function ProjectTo(x::NamedTuple)
elements = map(ProjectTo, x)
if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
return ProjectTo{NoTangent}()
else
return ProjectTo{Tangent{typeof(x)}}(; elements...)
end
end

# This method means that projection is re-applied to the contents of a Tangent.
# We're not entirely sure whether this is every necessary; but it should be safe,
# and should often compile away:
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::Tangent)
return project(backing(dx))
end

function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
len = length(project.elements)
if length(dx) != len
Expand All @@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
dy = map((f, x) -> f(x), project.elements, dx)
return project_type(project)(dy...)
end
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
dy = _project_namedtuple(backing(project), dx)
return project_type(project)(; dy...)
end

# Diffractor returns not necessarily a named tuple with all keys and of the same order as
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does Diffractor have to do with anything, and why does it return a namedtuple?
It should be a Tangent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It refers to #515 (comment). The Tangents are already unpacked at this stage.

# the projector
# Thus we can't use `map`
function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this function for?

Can't we just stick the thing into a Tangent{typeof(f), typeof(x)}(x) ?
which should robustly handly non-present keys and keys in different orders.
And if for some reason we can't handle that then add a canonicalize ?

Copy link
Member Author

@devmotion devmotion Nov 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is our custom projection map. Initially, in the first commit I just used map with the named tuple of projectors and named tuple of derivatives, as suggested by @mcabbott. However, map requires that the names of both named tuples are exactly identical, i.e., all derivatives are present and in the same order as the projectors. This function here just maps all existing derivatives and throws a more descriptive error if a derivative is present without corresponding projector.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could take short namedtuples and route them through Tangent -> canonise -> backing -> map -> Tangent, to re-use more stuff:

julia> using ChainRulesCore

julia> x = (a=1, b=2, c=3); dx = (b=400,);

julia> Tangent{typeof(x)}(; dx...)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(b = 400,)

julia> ChainRulesCore.canonicalize(ans)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())

julia> ChainRulesCore.backing(ans)
(a = ZeroTangent(), b = 400, c = ZeroTangent())

My slight reservation about all approaches really is whether we can insert enough complication to confuse Diffractor when it wants to take a 3rd derivative or something.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need the backing step?
why not
Tangent -> map, which already returns a Tangent?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, like this it doesn't:

julia> tang
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())

julia> projs = map(ProjectTo, x);

julia> map((f,x) -> f(x), projs, tang)
3-element Vector{Any}:
    ZeroTangent()
 400.0
    ZeroTangent()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't dug through all the functions closely, recently, but my reservation here is that this seems close to being a second use of canonicalize, just with a different carefully optimised generated implementation. It seems that if ever something breaks one, we'll have to fix both.

Is there a precedent anywhere else here about whether filling in all fields with NoTangent is preferable / not compared to leaving omitted ones out?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would preferably not fill in all fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was my understanding as well - and hence I don't think one should use canonicalize here since we don't want to fill all fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcabbott Are you OK with merging the PR as is and improving the implementation later, if e.g. there is a clear need for a two argument map?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, do it. I think this is the right behaviour. I do wish it could be shorter but that's not the end of the world. Sorry about dragging this out so long.

if @generated
vals = Any[
if xn[i] in fn
:(getfield(f, $(QuoteNode(xn[i])))(getfield(x, $(QuoteNode(xn[i])))))
else
throw(
ArgumentError(
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
),
)
end for i in 1:length(xn)
]
:(NamedTuple{$xn}(($(vals...),)))
else
vals = ntuple(Val(length(xn))) do i
name = xn[i]
if name in fn
getfield(f, name)(getfield(x, name))
else
throw(
ArgumentError(
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
),
)
end
end
NamedTuple{xn}(vals)
end
end

function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
for d in 1:ndims(dx)
if size(dx, d) != get(length(project.elements), d, 1)
Expand Down
36 changes: 36 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,42 @@ struct NoSuperType end
@test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
end

@testset "Base: NamedTuple" begin
pt1 = @inferred(ProjectTo((a=1.0,)))
@test @inferred(pt1((a=1 + im,))) ==
Tangent{NamedTuple{(:a,),Tuple{Float64}}}(; a=1.0)
@test @inferred(pt1(pt1((a=1,)))) == @inferred(pt1(pt1((a=1,)))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(; a=1))) == pt1((a=1,)) # accepts Tangent{Any}
@test @inferred(pt1(NoTangent())) === NoTangent()
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()

@test_throws Exception pt1((a=1, b=2)) # no projector for `b`
@test_throws Exception pt1((b=1,)) # no projector for `b`

# subset is allowed (required for Diffractor)
@test @inferred(pt1(NamedTuple())) === Tangent{NamedTuple{(:a,),Tuple{Float64}}}()

pt3 = @inferred(ProjectTo((a=[1, 2, 3], b=false, c=:gamma))) # partly non-differentiable
@test @inferred(pt3((a=1:3, b=4, c=5))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
a=[1.0, 2.0, 3.0], b=NoTangent(), c=NoTangent()
)

# different order
@test @inferred(pt3((b=4, a=1:3, c=5))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
b=NoTangent(), a=[1.0, 2.0, 3.0], c=NoTangent()
)

# only a subset
@test @inferred(pt3((c=5,))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
c=NoTangent()
)

@test @inferred(ProjectTo((a=true, b=[false]))) isa ProjectTo{NoTangent}
end

@testset "Base: non-diff" begin
@test ProjectTo(:a)(1) == NoTangent()
@test ProjectTo('b')(2) == NoTangent()
Expand Down