# Use ProjectTo in broadcasting & gradient #1044
````diff
@@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d
          p = size(x, d)
          sum(x.^p .+ y)
        end
-([14.0, 22.0], 2, nothing)
+([14.0, 22.0], 2.0, nothing)
 ```
 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  isnothing(grad) ? nothing : map(_project, args, grad)
 end
````
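As the docstring change above shows, `gradient` now projects each result onto its corresponding argument: integer inputs yield floating-point gradients, and (per the comment later in this diff) structured array types are preserved. A minimal sketch of the observable difference, assuming this PR's branch of Zygote:

```julia
using Zygote, LinearAlgebra

# Mirrors the docstring example above: the gradient w.r.t. the Int argument
# y = 0 is now projected to 2.0 (Float64) instead of the unprojected Int 2.
gradient((x, y) -> sum(x .^ 2 .+ y), [7, 11], 0)   # ([14.0, 22.0], 2.0)

# The diff's comment says ProjectTo also handles structured matrices, so a
# Diagonal argument should get a Diagonal gradient back (hedged expectation):
gradient(x -> sum(abs2, x), Diagonal([1.0, 2.0]))[1]
```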
> **Review comment:** You can add a method to …
>
> **Reply:** Can you write exactly what method that would be?
>
> **Reply:** Something like …
>
> **Reply:** This is easy to try: …
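The thread is truncated, but the method under discussion is plausibly the `Base.map` overload that appears verbatim further down in this diff. A sketch, assuming the PR's own internal `_project` helper and `Params` type:

```julia
# Plausibly the method being discussed (it appears later in this diff): a
# one-line overload that makes map(_project, args, grad) a no-op when the
# argument tuple is (::Params,), so implicit-parameter gradients skip projection.
Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad
```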
```diff
-Base.adjoint(f::Function) = x -> gradient(f, x)[1]
+# Base.adjoint(f::Function) = x -> gradient(f, x)[1]  # piracy!
+Base.adjoint(f::Function) = x -> begin  # still piracy! avoids projection for legacy reasons
+  y, back = pullback(f, x)
+  back(sensitivity(y))[1]
+end

 """
     withgradient(f, args...)
```
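The `adjoint` overload above is what powers the `f'` shorthand; it deliberately calls `pullback` directly so the result skips projection. A hedged usage sketch:

```julia
using Zygote

square(x) = x^2
square'(3)  # 6 -- the f'(x) sugar via Base.adjoint(::Function); no _project applied
```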
```diff
@@ -95,7 +100,9 @@ true
 """
 function withgradient(f, args...)
   y, back = pullback(f, args...)
-  (val = y, grad = back(sensitivity(y)))
+  grad = back(sensitivity(y))
+  results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
+  (val=y, grad=results)
 end

 # Param-style wrappers
```
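`withgradient` now projects its `grad` field the same way `gradient` does; a hedged usage sketch:

```julia
using Zygote

nt = withgradient(x -> sum(abs2, x), [1, 2])
nt.val   # 5
nt.grad  # ([2.0, 4.0],) -- projected to Float64, matching gradient's behaviour
```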
```diff
@@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do
 Grads(...)

 julia> g[x]
-2×3 Matrix{Int64}:
-  7  70  700
-  8  80  800
+2×3 Matrix{Float64}:
+ 7.0  70.0  700.0
+ 8.0  80.0  800.0

 julia> haskey(g, z)  # only x and y are parameters
 false
```
```diff
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
 @forward Params.order Base.iterate, Base.length, Base.getindex
 @forward Params.params Base.in

+Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad  # skip _project in gradient(f, ::Params)
+
 function Base.union!(ps::Params, itrs...)
   foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
   return ps
```
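The added `Base.map` overload means the implicit-parameter path bypasses projection entirely; a hedged sketch of the call path it affects:

```julia
using Zygote

W = rand(2, 3)
g = gradient(() -> sum(W), Params([W]))  # implicit-parameter form returns Grads
g[W]  # raw pullback result: map(_project, (::Params,), grad) returned grad untouched
```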
---

The second file in the diff reworks broadcasting's `unbroadcast`:
```diff
@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray
   Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
 end

-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
-
-unbroadcast(x::AbstractArray, x̄) =
-  size(x) == size(x̄) ? x̄ :
-  length(x) == length(x̄) ? trim(x, x̄) :
-  trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
-
+function unbroadcast(x::AbstractArray, x̄)
+  N = ndims(x̄)
+  if length(x) == length(x̄)
+    _project(x, x̄)  # ProjectTo handles reshape, offsets, structured matrices, row vectors
+  else
+    dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
+    _project(x, accum_sum(x̄; dims = dims))
+  end
+end
 unbroadcast(x::Number, x̄) = accum_sum(x̄)
 unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
 unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄)))  # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄)))  # case length(x) > 1

 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
```

> **Review comment** (on replacing `trim` with `_project`): I think doing this makes `unbroadcast` less generic, we don't need to define projections here afaict. Let's retain the current definition.
>
> **Reply:** What case exactly is not handled, if this is less generic?
>
> **Reply:** It restricts it to what can be handled by …
>
> **Reply:** Those are broadly the same now, as of recent changes.
>
> **Reply:** Note that before CRC changes, …
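To see what routing `unbroadcast` through `_project` buys, a hedged sketch; the diff's own comment lists reshape, offsets, structured matrices, and row vectors as the cases ProjectTo handles:

```julia
using Zygote, LinearAlgebra

# Broadcasting against a transposed vector: the gradient should come back as a
# 1×3 covector (Adjoint-like) rather than a plain dense Matrix, since _project
# re-imposes the argument's type on the cotangent (hedged expectation).
x = [1.0, 2.0, 3.0]
gradient(xt -> sum(xt .+ 1), x')[1]
```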