Use abstract types for projection #391
Current behaviour:
src/projection.jl
eltype(x) == Bool && return ProjectTo(false)
sub = ProjectTo(parent(x))
- eltype(x) == Bool && return ProjectTo(false)
- sub = ProjectTo(parent(x))
+ sub = ProjectTo(parent(x))
+ # if our parent is going to zero then we are also going to zero
+ sub isa ProjectTo{<:AbstractZero} && return sub
And this will be fine because we have similar logic for the ProjectTo on parent, so eventually it will check the eltype.
That's what I thought. But when I tried, some of the wrappers were not happy to propagate Zero.
> But when I tried, some of the wrappers were not happy to propagate Zero.

Do we just need to fix those ones?
Yea I tried that too. But it was late & didn't seem a priority.
Can we leave a TODO and open an issue about this, and cross link it here?
Or make this work
FWIW there are still shortcuts (and tests) which work around things like this:
julia> Diagonal(NoTangent())
ERROR: MethodError: no method matching Diagonal(::NoTangent)
julia> Symmetric(NoTangent())
ERROR: MethodError: no method matching Symmetric(::NoTangent)
Not sure where such methods ought to be defined, if they should be.
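One way to picture the kind of shortcut mentioned above, without defining constructor methods such as `Diagonal(::NoTangent)`, is a small helper that lets a zero pass through before re-wrapping. This is only a sketch under stated assumptions: `ToyZero` is a hypothetical stand-in for `NoTangent()`, and `rewrap_diagonal` is an invented name, not something from ChainRulesCore.

```julia
using LinearAlgebra

# Hypothetical stand-in for NoTangent(); not the ChainRulesCore type.
struct ToyZero end

# Hypothetical helper: re-wrap a projected parent as a Diagonal,
# but let a zero tangent pass straight through instead of calling Diagonal on it.
rewrap_diagonal(dx::ToyZero) = dx
rewrap_diagonal(dx::AbstractVector) = Diagonal(dx)

zero_result = rewrap_diagonal(ToyZero())      # stays a ToyZero, no MethodError
dense_result = rewrap_diagonal([1.0, 2.0])    # becomes a Diagonal matrix
```

Whether the real fix belongs in a helper like this or in methods on the wrapper types is exactly the open question in the comment above.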
I think this PR would be much enhanced by adding in the various things that can be the identity projection based on subspaces discussed in JuliaDiff/ChainRules.jl#467, which is what using abstract types should make simpler, right? That is why we are doing that?
If I'm thinking correctly,
The case you mention is a subspace which is not a subtype. It would be easy to just add such methods, by hand, for known pairs [Edit -- f625b68 does this]. Perhaps this could be souped-up to something with a data structure it loops through. I wonder a bit whether this complication is necessary -- how often will you actually manage to produce a combination like this, and how often will the downstream methods be more efficient for Diagonal than for Symmetric{..., Diagonal}? Are there more compelling examples? But more generally, are there other weird edge cases which it might be difficult or impossible to bend this structure to fit? I think this is the important thing to try to figure out first. There is quite a bit of tidying up which can evidently be done here, but if there are unfixable problems then let's find them as soon as we can.
Pretty much all subspaces are not subtypes, with the sole exception that basically all our subspace-representing array types are subtypes of AbstractArray.
#390 looks super gross, so yeah I would rather not do a big loop over things.
Putting on my ChainRules project leader's hat and making the hard calls. All in all, my conclusion is that we should go ahead with the change to make things that are subspaces pass through with just the identity transform, by having rules for And we can fix up a bunch of behaviour around scalars. I am undecided if we should keep But we should leave off the Later we can add a macro to make writing it easier (like you say here JuliaDiff/ChainRules.jl#459 (comment)), and we can probably make that macro configurable in interesting ways (like have it look at the signature and only transform the arguments that are AbstractArray if that is what the user wants). I know it is a stereotype to express appreciation after laying down the law, If timelines were different we might be able to continue discussing and playing around with ideas, but they are not.
Note that this is a choice to be made, no matter the precise implementation of the projector mechanism. But it does influence how this should be designed. The proposal of FluxML/Zygote.jl#965 was to apply this everywhere, and correspondingly, the projector was pretty forgiving, essentially only correcting known problems and letting unknown types through. The projector defined in #385 is the opposite, completely inflexible about the type of the gradient, for arbitrary structs; this probably means it has to be applied pretty selectively. This PR is somewhere between, and it's not so clear yet how it ought to treat generic structs.
This complication is about the particular idea of having two different And if we don't apply them globally, there are other ways we could slash the boilerplate budget compared to the present state of JuliaDiff/ChainRules.jl#459.
Once we have tangents like Comments are here #391 (comment) but should be an issue I think. Not going to do anything about it here for now.
I'm pretty sure these are both infinite sets. What I mean by "compelling examples" is actual concrete types to think about. There are trade-offs between mathematical generality and simplicity, and I'm not convinced we need to design around this at all.
For now I will
Some questions:
Now solves all the issues at JuliaDiff/ChainRules.jl#467 (comment), except it's more strict about what it reshapes, and |
Codecov Report
@@ Coverage Diff @@
## master #391 +/- ##
==========================================
+ Coverage 91.95% 92.17% +0.21%
==========================================
Files 15 14 -1
Lines 634 754 +120
==========================================
+ Hits 583 695 +112
- Misses 51 59 +8
I wonder if instead of |
They could overload The current setup will kill rationals. It could be |
My naïve suggestion would be that I’d like to see |
Bump. If we're going to do roughly this, shall we merge this PR & then do whatever fiddling necessary on top of it? I made a branch to start applying this to Zygote, to try out... BTW, re sparse arrays, what's done here is a very naive first cut. But did relevant people see this? https://discourse.julialang.org/t/chinas-alternative-to-gsoc-2021-14-julia-related-programs/64473 https://summer.iscas.ac.cn/#/org/prodetail/210370152?lang=en
I will review tomorrow.
It changes some return types, so breaking very strict tests of previous behaviour is inevitable:
Whether it breaks anything more of course deserves a close look.
Is the new result more correct?
The ChainRules failure is from things like Nevertheless, it's possible that you may want to delay the changes to |
Sorry, I have to go now, I will finish reviewing this evening
Here are my notes so far
src/projection.jl
ProjectTo() = ProjectTo{Any}() # trivial case, exists so that maybe_call(f, x) knows what to do
(x::ProjectTo{Any})(dx) = dx
do we need this?
- ProjectTo() = ProjectTo{Any}() # trivial case, exists so that maybe_call(f, x) knows what to do
- (x::ProjectTo{Any})(dx) = dx
This exists so that containers of unknown types can be handled consistently.
E.g. for `x::Ref`, you always get `ProjectTo{Ref}(; x = ProjectTo...`, and you always call `p.x(dx)` when projecting. But if `x[]` is some weird struct, then my first attempt stored `p.dx == identity`. But then `maybe_call(f, x)` doesn't know whether this `identity` is a part of the original struct which has been saved, or an indicator of the trivial projection.
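The ambiguity described here can be illustrated with toy types. This is a minimal sketch, assuming invented names throughout: `TrivialProjector`, `Holder` and `is_marker` are hypothetical and are not the definitions in `src/projection.jl`. The point is only that a dedicated marker type can be told apart from a stored `identity` by dispatch, whereas a bare `identity` cannot.

```julia
# Hypothetical marker meaning "no projection needed"; plays the role of ProjectTo{Any}.
struct TrivialProjector end
(::TrivialProjector)(dx) = dx          # calling it is the identity map

# Hypothetical container with one saved field.
struct Holder{F}
    x::F
end

# The marker is distinguishable from saved data purely by its type.
is_marker(h::Holder) = h.x isa TrivialProjector

marked = Holder(TrivialProjector())    # field means "project trivially"
saved  = Holder(identity)              # field is genuinely the function `identity`

result = marked.x(3.0)                 # applying the trivial projector returns dx unchanged
```

Had both cases stored `identity`, `marked` and `saved` would have the same type and nothing downstream could tell them apart.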
I see.
I guess this is safe enough.
Though if we stop trying to do things recursively (which I am in favor of for now)
we could get rid of it?
I think going into Ref is a good idea. This is the preferred way to store one mutable parameter, right? And shows up a lot for broadcasting etc. too, although I don't know whether we end up involved ever.
Co-authored-by: Lyndon White <[email protected]>
This reverts commit 1a0c76d.
Thanks, huge PR, great work.
This is a start on adapting `ProjectTo` to address issues raised in JuliaDiff/ChainRules.jl#467, by means of storing an abstract type as a label to indicate the appropriate subspace.

This idea may have holes in it; I think that it would be worth trying to think up concrete adversarial examples to find them.

One issue we discussed is how Diagonal should be projected onto Symmetric, and #390 is some ideas to formalise that sort of thing. The behaviour here is: was: Edit -- now:

At the end of this file, there is also a very crude sketch of a `proj_rrule` which applies projection to all arguments. The idea being roughly that AD should call this, while users should overload `rrule`; but they could opt out of projection by overloading `proj_rrule`. It should pass through `@thunk` but will destroy in-place thunks.

Tests will certainly fail, almost everything I've tried is pasted into the docstring (for now).
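As a rough illustration of the two ideas in this description, here is a toy sketch. Everything in it is an assumption made for illustration: `ToyProject`, `toyproject`, `toy_rrule` and `toy_proj_rrule` are invented names, the rule returns tangents for the arguments only, and thunks are not handled. It is not the `ProjectTo` or `proj_rrule` from this PR.

```julia
# Toy projector labelled by an abstract type that names the subspace.
struct ToyProject{T} end

toyproject(::Real) = ToyProject{Real}()
toyproject(::Complex) = ToyProject{Complex}()

(::ToyProject{Real})(dx::Real) = dx            # already in the subspace: identity
(::ToyProject{Real})(dx::Complex) = real(dx)   # drop the imaginary part
(::ToyProject{Complex})(dx::Number) = dx       # real tangents are acceptable for complex inputs

# Toy rule for multiplication of numbers: returns the value and a pullback.
toy_rrule(::typeof(*), x::Number, y::Number) =
    (x * y, dz -> (dz * conj(y), conj(x) * dz))

# Toy wrapper: call the rule, then project each tangent onto its argument's subspace.
function toy_proj_rrule(f, args...)
    y, back = toy_rrule(f, args...)
    projectors = map(toyproject, args)
    projected_back(dy) = map((p, dx) -> p(dx), projectors, back(dy))
    return y, projected_back
end

y, back = toy_proj_rrule(*, 2.0, 1.0 + 1.0im)
grads = back(1.0)   # the tangent for the real argument is projected back to a real number
```

In this sketch a caller that wants unprojected tangents would call `toy_rrule` directly, which loosely mirrors the opt-out idea described above.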