Add ProjectTo(::NamedTuple)
#515
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -287,7 +287,7 @@ end | |
| # Since this works like a zero-array in broadcasting, it should also accept a number: | ||
| (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx)) | ||
|
|
||
| # Tuple | ||
| # Tuple and NamedTuple | ||
| function ProjectTo(x::Tuple) | ||
| elements = map(ProjectTo, x) | ||
| if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}} | ||
|
|
@@ -296,10 +296,22 @@ function ProjectTo(x::Tuple) | |
| return ProjectTo{Tangent{typeof(x)}}(; elements=elements) | ||
| end | ||
| end | ||
| function ProjectTo(x::NamedTuple) | ||
| elements = map(ProjectTo, x) | ||
| if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}} | ||
| return ProjectTo{NoTangent}() | ||
| else | ||
| return ProjectTo{Tangent{typeof(x)}}(; elements...) | ||
| end | ||
| end | ||
|
|
||
| # This method means that projection is re-applied to the contents of a Tangent. | ||
| # We're not entirely sure whether this is ever necessary; but it should be safe, | ||
| # and should often compile away: | ||
| (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx)) | ||
| function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::Tangent) | ||
| return project(backing(dx)) | ||
| end | ||
|
|
||
| function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) | ||
| len = length(project.elements) | ||
| if length(dx) != len | ||
|
|
@@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) | |
| dy = map((f, x) -> f(x), project.elements, dx) | ||
| return project_type(project)(dy...) | ||
| end | ||
| function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple) | ||
| dy = _project_namedtuple(backing(project), dx) | ||
| return project_type(project)(; dy...) | ||
| end | ||
|
|
||
| # Diffractor does not necessarily return a named tuple with all keys, or with keys in the | ||
| # same order as the projector, so we can't use `map` here. | ||
| function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt} | ||
|
Member
what is this function for? Can't we just stick the thing into a
Member
Author
It is our custom projection
Member
I guess it could take short namedtuples and route them through Tangent -> canonise -> backing -> map -> Tangent, to re-use more stuff: My slight reservation about all approaches really is whether we can insert enough complication to confuse Diffractor when it wants to take a 3rd derivative or something.
Member
why do we need the
Member
Well, like this it doesn't:
Member
I haven't dug through all the functions closely, recently, but my reservation here is that this seems close to being a second use of Is there a precedent anywhere else here about whether filling in all fields with NoTangent is preferable / not compared to leaving omitted ones out?
Member
We would preferably not fill in all fields.
Member
Author
This was my understanding as well - and hence I don't think one should use
Member
Author
@mcabbott Are you OK with merging the PR as is and improving the implementation later, if e.g. there is a clear need for a two argument
Member
Yes, do it. I think this is the right behaviour. I do wish it could be shorter but that's not the end of the world. Sorry about dragging this out so long. |
||
| if @generated | ||
oxinabox marked this conversation as resolved.
|
||
| vals = Any[ | ||
| if xn[i] in fn | ||
| :(getfield(f, $(QuoteNode(xn[i])))(getfield(x, $(QuoteNode(xn[i]))))) | ||
| else | ||
| throw( | ||
| ArgumentError( | ||
| "named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])", | ||
| ), | ||
| ) | ||
| end for i in 1:length(xn) | ||
| ] | ||
| :(NamedTuple{$xn}(($(vals...),))) | ||
| else | ||
| vals = ntuple(Val(length(xn))) do i | ||
| name = xn[i] | ||
| if name in fn | ||
| getfield(f, name)(getfield(x, name)) | ||
| else | ||
| throw( | ||
| ArgumentError( | ||
| "named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])", | ||
| ), | ||
| ) | ||
| end | ||
| end | ||
| NamedTuple{xn}(vals) | ||
| end | ||
| end | ||
|
|
||
| function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) | ||
| for d in 1:ndims(dx) | ||
| if size(dx, d) != get(length(project.elements), d, 1) | ||
|
|
||
what does Diffractor have to do with anything, and why does it return a namedtuple? It should be a Tangent.
It refers to #515 (comment). The Tangents are already unpacked at this stage.
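For readers following the thread, here is a rough sketch of the behaviour being discussed: the incoming named tuple may carry only a subset of the projector's keys, possibly in a different order, and the result keeps the keys and order of the input rather than filling in the missing fields. It is written in plain Julia with no ChainRulesCore dependency. The name `project_subset` and the stand-in projectors `Float64` and `real` are hypothetical and only mirror the non-generated branch of `_project_namedtuple` from the diff; this is not the code in the PR.

```julia
# Hypothetical stand-in for `_project_namedtuple`: applies the per-key
# projectors stored in `f` to the entries of `dx`, keeping the keys and
# the key order of `dx` (not those of `f`).
function project_subset(f::NamedTuple{fn}, dx::NamedTuple{xn}) where {fn,xn}
    vals = ntuple(Val(length(xn))) do i
        name = xn[i]
        # A key in `dx` that the projector does not know about is an error.
        name in fn || throw(ArgumentError(
            "named tuple with keys $fn cannot have a gradient with key $name"))
        getfield(f, name)(getfield(dx, name))
    end
    return NamedTuple{xn}(vals)
end

# Assumed stand-in projectors: `a` is converted to Float64, `b` keeps its real part.
projectors = (a = Float64, b = real)

subset = project_subset(projectors, (b = 2 + 3im,))       # only one of the keys
reordered = project_subset(projectors, (b = 1.0, a = 1))  # keys in a different order
```

As far as I know, `map(f, nt1, nt2)` on named tuples requires both arguments to have identical names, which would explain why the diff's comment rules out `map` once keys can be missing or reordered.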