Skip to content

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jul 6, 2021

This is a start on adapting ProjectTo to address issues raised in JuliaDiff/ChainRules.jl#467, by means of storing an abstract type as a label to indicate the appropriate subspace.

This idea may have holes in it; I think that it would be worth trying to think up concrete adversarial examples to find them.

One issue we discussed is how Diagonal should be projected onto Symmetric, and #390 is some ideas to formalise that sort of thing. The behaviour here is: was:

julia> d = ProjectTo(Diagonal([1,2,3]));  # Int promoted to Float64, could argue for Real?

julia> s = ProjectTo(Symmetric(rand(3,3)))
ProjectTo{Symmetric{Float64, AbstractMatrix{Float64}}}(uplo = :U, parent = ProjectTo{AbstractMatrix{Float64}}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3), Base.OneTo(3))))

julia> s(d(reshape(1:9,3,3)))  # could argue that Diagonal is better, valid subspace? or worse, wrong wrapper?
3×3 Symmetric{Float64, Diagonal{Float64, Vector{Float64}}}:
 1.0  0.0  0.0
 0.0  5.0  0.0
 0.0  0.0  9.0

julia> s([1 1 1; 2 2 2; 100 3 3])
3×3 Symmetric{Float64, Matrix{Float64}}:
  1.0  1.5  50.5
  1.5  2.0   2.5
 50.5  2.5   3.0

Edit -- now:

julia> s = ProjectTo(Symmetric(rand(3,3)))
ProjectTo{Symmetric}(uplo = :U, parent = ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3), Base.OneTo(3))))

julia> s(d(reshape(1:9,3,3)))  # could argue that Diagonal is better, valid subspace? or worse, wrong wrapper?
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅ 
  ⋅   5.0   ⋅ 
  ⋅    ⋅   9.0

At the end of this file, there is also a very crude sketch of a proj_rrule which applies projection to all arguments. The idea being roughly that AD should call this, while users should overload rrule; but they could opt out of projection by overloading proj_rrule. It should pass through @thunk but wil ldestroy inplacethunks.

Tests will certainly fail, almost everything I've tried is pasted into the docstring (for now).

@mcabbott
Copy link
Member Author

mcabbott commented Jul 6, 2021

What this doesn't affect is the handling of other structs, which retains the completely strict behaviour. That seems tricky to solve, for any scheme that allows array/number types not to match exactly.

Current behaviour:

julia> struct Two{T,S}; x::T; y::S; end

julia> tw = Two(1, [1,2,3]');

julia> Tangent{typeof(tw)}(x=tw.x, y=tw.y)
Tangent{Two{Int64, Adjoint{Int64, Vector{Int64}}}}(x = 1, y = [1 2 3])

julia> p = ChainRulesCore.generic_projectto(tw)  # generic_projectto exists, but not used generically
ProjectTo{Two}(x = ProjectTo{Float64}(), y = ProjectTo{Adjoint}(parent = ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),)),))

julia> p(Tangent{typeof(tw)}(x=1+im, y=ones(1,3)))
Two{Float64, Adjoint{Float64, Vector{Float64}}}(1.0, [1.0 1.0 1.0])

julia> ans.y  # should be floats
1×3 adjoint(::Vector{Float64}) with eltype Float64:
 1.0  1.0  1.0

Comment on lines 209 to 256
eltype(x) == Bool && return ProjectTo(false)
sub = ProjectTo(parent(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eltype(x) == Bool && return ProjectTo(false)
sub = ProjectTo(parent(x))
sub = ProjectTo(parent(x))
# if our parent is going to zero then we are also going to zero
sub <: ProjectTo{<:AbstractZero} && return sub

And this will be fine because was have similar logic for the ProjectTo on parent, so eventually will check the eltype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I thought. But when I tried, some of the wrappers were not happy to propagate Zero.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But when I tried, some of the wrappers were not happy to propagate Zero.

Do we just need to fix those ones?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I tried that too. But it was late & didn't seem a priority.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave a TODO and open an issue about this, and cross link it here?
Or make this work

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW there are still shortcuts (and tests) which work around things like this:

julia> Diagonal(NoTangent())
ERROR: MethodError: no method matching Diagonal(::NoTangent)

julia> Symmetric(NoTangent())
ERROR: MethodError: no method matching Symmetric(::NoTangent)

Not sure where such methods ought to be defined, if they should be.

@oxinabox
Copy link
Member

oxinabox commented Jul 7, 2021

I think this PR would be much enhanced by adding in the various things that can be the identity projection based on subspaces discussed in JuliaDiff/ChainRules.jl#467,
like ProjectTo{Symmetric{...}}(dx::Diagonal) = Diagonal.

Which is what using abstract types should make simpler, right? That is why we are doign that?

@mcabbott
Copy link
Member Author

mcabbott commented Jul 7, 2021

Which is what using abstract types should make simpler, right?

If I'm thinking correctly, T <: S is sufficient but not necessary for dx::T to be in a subspace of dy::S. Those are the easy cases, like accepting dx::Diagonal when x::Matrix by storing S::AbstractMatrix in the projector.

adding ProjectTo{Symmetric{...}}(dx::Diagonal) = dx [surely]

The case you mention is a subspace which is not a subtype. It would be easy to just add such methods, by hand, for known pairs [Edit -- f625b68 does this]. Perhaps this could be souped-up to something with a data structure it loops through. I wonder a bit whether this complication is necessary -- how often will you actually manage to produce a combination like this, and how often the downstream methods will be more efficient for Diagonal than for Symmetric{..., Diagonal}. Are there more compelling examples?

But more generally, are there other weird edge case which it might be difficult or impossible to bend this structure to fit? I think this is the important thing to try to figure out first. There is quite a bit of tidying up which can evidently be done here, but if there are unfixable problems then let's find them as soon as we can.

@oxinabox
Copy link
Member

oxinabox commented Jul 7, 2021

The case you mention is a subspace which is not a subtype.

Pretty much all subspaces are not subtypes, with the soul exception for basically all our subspace representing array types being a subtype of AbstractArray.
Which to be fair is the one that saves us a lot, and the most often.

Perhaps this could be souped-up to something with a data structure it loops through. I wonder a bit whether this complication is necessary -- how often will you actually manage to produce a combination like this, and how often the downstream methods will be more efficient for Diagonal than for Symmetric{..., Diagonal}. Are there more compelling examples?

#390 looks super gross, so yeah I would rather not do a big loop over things.

But more generally, are there other weird edge case which it might be difficult or impossible to bend this structure to fit? I think this is the important thing to try to figure out first.

  • The assumption that anything we don't have a more specific projector defined for is dense is not great, but might be the best compromise. Not completely convince we can't do something else though that uses the constructor. See the example here Use abstract types for projection #391 (comment)
  • Projecting everything all the time (rather than just things the rule authors chose to) makes for a bunch of hard edge cases
  • To make InplacableThunks work (had it not broken them) we would still need to have separate rules for (Strided?)Array and AbstractArray. Which does make it feel a bit a silly to be projecting when we are in a case we know we don't have to. But it should be ok since
  • It's going to complicate the way ADs connect to this per https://github.com/JuliaDiff/ChainRulesCore.jl/pull/391/files#r665592409

@oxinabox
Copy link
Member

oxinabox commented Jul 7, 2021

Putting on my ChainRules project leader's hat. and making the hard calls.
Bearing in mind we have imminent deadlines, because we need to actually have the JuliaCon talk that teachs people how to use this stuff done and recorded in not so many days from now (this is with the extension).

All in all, my conclusion is that we should go ahead with the change to make things that are subspaces pass through with just the identity transform, by having rules for AbstractArray.
And that yes, we can assume all AbtractArrays are indeed dense-unstructured, until told otherwise.
(its wrong but better than suffering)
This is nonbreaking change, and more a optimization than anything else
The API remains the same as it is now, its just a bunch of copies etc get removed.

And we can fix up a bunch of behavour around scalars.
Those are bug-fixes

I am undecide if we should keep ProjectTo{T}(::AbstractZero)::T
Removing that is a breaking change, but it isn't a big deal.
What do you think? I think getting rid of that probably simplifies the Arrays of Arrays cases?

But we should leave off the proj_rrule stuff, and stick to doing it like in JuliaDiff/ChainRules.jl#459
where the rule author chooses to do so.
I agree with you it is less ergonomic.
But it puts full control in the user's hands, and it is working now.
And it takes out some of the more challenging cases from this PR.
It keeps the conceptual model the same with just 1 notion of a rrule, and avoids the confusions around overloading rrule vs overloading proj_rrule etc,.

Later we can add a macro to make writing it easier (like you say here JuliaDiff/ChainRules.jl#459 (comment)), and we can probably make that macro configurable in interesting ways (like have it look at the signature and only transform the arguments that are AbstractArray if that is what the user wants).
We also later might look into other ways of solving the problem that can be used in certain circumstances, like structural-zero-preserving-map.

I know it is a stereotype, to express appreciation after laying down the law,
but I really do appreciate the work you have put in, not just in this PR but in all the discussion leading up to it and like everything else.

If timelines were different we might be able to continue discussing and playing around with ideas, but they are not.
There is always ChainRulesCore 2.0 in a few years, and we might still be able to land something like proj_rrule as a optional feature during the 1.x period.
But it's not going to make it into 1.0

@mcabbott
Copy link
Member Author

mcabbott commented Jul 7, 2021

Projecting everything all the time (rather than just things the rule authors chose to)

Note that this is a choice to be made, no matter the precise implementation of the projector mechanism. But it does influence how this should be designed.

The proposal of FluxML/Zygote.jl#965 was to apply this everywhere, and correspondingly, the projector was pretty forgiving, essentially only corrects known problems and lets unknown types through. The projector defined in #385 is the opposite, completely inflexible about the type of the gradient, for arbitrary structs; this probably means it has to be applied pretty selectively. This PR is somewhere between, and it's not so clear yet how it ought to treat generic structs.

It's going to complicate the way ADs connect to this

This complication is about the particular idea of having two different rrule functions, one of which adds projectors to the other, and then it goes to AD. That's not the only mechanism by which projectors could be applied globally, if we wanted to do that. FluxML/Zygote.jl#965 had a different one.

And if we don't apply them globally, there are other ways we could slash the boilerplate budget compared to the present state of JuliaDiff/ChainRules.jl#459.

breaking InplacableThunks

Once we have tangents like dx::Diagonal, or mixes of eltypes, then I believe InplacableThunks are already broken. To a large degree this problem has to be solved regardless, perhaps by restricting what types they are ever applied to (such as dx::StridedArray{OK_type})

Comments are here #391 (comment) but should be an issue I think. Not going to do anything about it here for now.

Pretty much all subspaces are not subtypes

I'm pretty sure these are both infinite sets. What I mean by "compelling examples " is actual concrete types to think about. There are trade-offs between mathematical generality and simplicity, and I'm not convinced we need to design around this at all.

@mcabbott
Copy link
Member Author

mcabbott commented Jul 7, 2021

For now I will

  • Delete the proj_rrule experiment, leave the rule to be applied by hand (or within AD, e.g. broadcasting) 8e400fc
  • Restrict ProcectTo to act only on AbstractArray & Number, everything else an error (for now). Then any clever recursive behaviour on structs can be non-breaking post-1.0
  • Keep the quite strict number type, oftype(float(x), dx) We should try out some Forward-over-Reverse things before tagging. Changed to preserve simple HW floats, apply float(::Integer), and otherwise use Real etc.
  • Keep the AbstractZero pass-through.
  • Treat arrays of arrays as before, and make sparse arrays preserve sparsity -- Done, probably slow & possibly buggy.

Some questions:

  • Should ProjectTo(fill(1))(pi) work? And ProjectTo(Ref(1))(pi)? These seem like examples of subspace-but-not-subtype which are likely to occur, e.g. gradient(x -> x .+ 1, Ref(2)).
  • Should @scalar_rule insert projectors? This case is much simpler than doing all rrules. Not this PR though. Done.
  • Without doing all rrules, at least the user-facing Zygote.gradient, and its broadcasting, should probably always apply ProjectTo (to allowed argument types).

Now solves all the issues at JuliaDiff/ChainRules.jl#467 (comment), except it's more strict about what it reshapes, and UpperTriangular(Diagonal( isn't simplified.

@codecov-commenter
Copy link

codecov-commenter commented Jul 8, 2021

Codecov Report

Merging #391 (0dedf92) into master (38d3929) will increase coverage by 0.21%.
The diff coverage is 96.21%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #391      +/-   ##
==========================================
+ Coverage   91.95%   92.17%   +0.21%     
==========================================
  Files          15       14       -1     
  Lines         634      754     +120     
==========================================
+ Hits          583      695     +112     
- Misses         51       59       +8     
Impacted Files Coverage Δ
src/projection.jl 94.58% <96.21%> (-1.20%) ⬇️
src/differentials/thunks.jl 94.89% <0.00%> (ø)
src/deprecated.jl

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 38d3929...0dedf92. Read the comment docs.

@MasonProtter
Copy link
Contributor

I wonder if instead of float(x) if there could be some continuum(x) or differentiable(x) interface. E.g. it might be more natural to turn certain types into Rational instead of Float depending on the user’s wishes.

@mcabbott
Copy link
Member Author

mcabbott commented Jul 8, 2021

They could overload ProjectTo(x::WeirdNumber) to do this, I think.

The current setup will kill rationals. It could be ProjectTo(x::Integer) = ...(float(x)) to target this more narrowly. It's nice that ProjectTo(pi)(1) converts to floats now, could overload irrationals too of course.

@MasonProtter
Copy link
Contributor

My naïve suggestion would be that I’d like to see Rationals not get killed when possible. As far as I’m aware, they’re a perfectly good differentiable field (so long as you don’t hit transcendental functions), and could be very nice to use in some cases. But I’m not totally sure what the corresponding trade offs are

@mcabbott
Copy link
Member Author

Bump. If we're going to do roughly this, shall we merge this PR & then do whatever fiddling necessary on top of it? I made a branch to start applying this to Zygote, to try out...

BTW, re sparse arrays, what's done here is a very naiive first cut. But did relevant people see this?

https://discourse.julialang.org/t/chinas-alternative-to-gsoc-2021-14-julia-related-programs/64473

https://summer.iscas.ac.cn/#/org/prodetail/210370152?lang=en

@oxinabox
Copy link
Member

I will review tomorrow.
Looks like something that changed in @scalar_rule is breaking the ChainRules.jl tests.
Probably need to be fixed before we can merge.

@mcabbott
Copy link
Member Author

It changes some return types, so breaking very strict tests of previous behaviour inevitable:

(∂x, ∂y) isa Tuple{T, T}
   Evaluated: (1.1475439f0, -3.144237995147705) isa Tuple{Float32,Float32}

Whether it breaks anything more of course deserves a close look.

@oxinabox
Copy link
Member

Is the new result more correct?

@mcabbott
Copy link
Member Author

mcabbott commented Jul 12, 2021

The ChainRules failure is from things like _, ∂x, ∂y = rrule(/, 3.14f0, 2)[2](Δz), here:
https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/rulesets/Base/fastmath_able.jl#L151-L154
The original issue was that x::Float32 was once upon a time accidentally promoted to ∂x::Float64 by integer division. That's not broken. With y::Int, the test also demands that ∂y::Float32, while any projection rule that calls float on integers will give ∂y::Float64.

Nevertheless, it's possible that you may want to delay the changes to @scalar_rule for a breaking release. Your call I guess. I didn't originally intend to put them in this PR but started fiddling on the same branch... Now #395 is that change.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I have to go now, I will finish reviewing this evening
Here are my notes so far

Comment on lines 101 to 102
ProjectTo() = ProjectTo{Any}() # trivial case, exists so that maybe_call(f, x) knows what to do
(x::ProjectTo{Any})(dx) = dx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this?

Suggested change
ProjectTo() = ProjectTo{Any}() # trivial case, exists so that maybe_call(f, x) knows what to do
(x::ProjectTo{Any})(dx) = dx

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exists so that containers of unknown types can be handled consistently.

E.g. for x::Ref, you always get ProjectTo{Ref}(; x = ProjectTo..., and you always call p.x(dx) when projecting. But if x[] is some weird struct, then my first attempt stored p.dx == identity. But then maybe_call(f, x) doesn't know whether this identity is a part of the original struct which has been saved, or an indicator of the trivial projection.

Copy link
Member

@oxinabox oxinabox Jul 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
I guess this is safe enough.

Though if we stop trying to do things recursively (which I am infavor of for now)
we could get rid of it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think going into Ref is a good idea. This is the preferred way to store one mutable parameter, right? And shows up a lot for broadcasting etc. too, although I don't know whether we end up involved ever.

@mcabbott mcabbott changed the title RFC: use abstract types for projection Use abstract types for projection Jul 13, 2021
@oxinabox
Copy link
Member

Thanks, huge PR, great work.
Anything further can come in a follow up PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants