Description
Consider the following example from base Julia:
    evalpoly(x, p::AbstractVector) = _evalpoly(x, p)

    function _evalpoly(x, p)
        N = length(p)
        ex = p[end]
        for i in N-1:-1:1
            ex = muladd(x, ex, p[i])
        end
        ex
    end

This function is simple and looks at each element of p once. Ideally, its gradient would be fast. However:
    julia> using Zygote, BenchmarkTools

    julia> x, p = rand(), randn(10000);

    julia> @btime evalpoly(x, p);
      17.323 μs (1 allocation: 16 bytes)

    julia> @btime Zygote.gradient(evalpoly, x, p);
      679.433 ms (490088 allocations: 1.51 GiB)

To work around this, I'm adding a custom rule for evalpoly to ChainRules.jl (JuliaDiff/ChainRules.jl#190), which will speed things up dramatically once #366 is merged:
    julia> @btime Zygote.gradient(evalpoly, x, p);  # 10,000x faster than before!!
      43.338 μs (21 allocations: 156.81 KiB)

But this is a band-aid. How can we improve this in the general case? Is the main problem likely the adjoint for getindex (e.g. #365)? Even with a great adjoint for getindex, though, Zygote won't know that each element of p is used only once, so it can't do something fast like allocating a single gradient vector and filling it efficiently.
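To make the "single gradient vector" idea concrete, here is a minimal sketch of a hand-written reverse pass for the Horner loop. It is not the rule from JuliaDiff/ChainRules.jl#190; the name `evalpoly_with_pullback` is hypothetical, and the sketch assumes a real scalar `x` and a non-empty vector `p` of real numbers. The forward pass stores the intermediate Horner values, and the pullback allocates one gradient vector and fills it in a single sweep.

```julia
# Hypothetical helper for illustration only (not the ChainRules.jl rule).
# Assumes real `x` and a non-empty real vector `p`.
function evalpoly_with_pullback(x, p::AbstractVector)
    N = length(p)
    T = promote_type(typeof(x), eltype(p))

    # Forward pass: ys[N] = p[N], ys[i] = x * ys[i+1] + p[i], result is ys[1].
    ys = Vector{T}(undef, N)
    ys[N] = p[N]
    for i in N-1:-1:1
        ys[i] = muladd(x, ys[i+1], p[i])
    end

    # Reverse pass: one allocation for the gradient of p, filled in one sweep.
    function pullback(dy)
        dp = Vector{T}(undef, N)
        dx = zero(T)
        ybar = convert(T, dy)          # cotangent of ys[i], starting at i = 1
        for i in 1:N-1
            dp[i] = ybar               # ys[i] depends on p[i] with weight 1
            dx = muladd(ybar, ys[i+1], dx)  # ys[i] depends on x via x * ys[i+1]
            ybar *= x                  # propagate the cotangent to ys[i+1]
        end
        dp[N] = ybar                   # ys[N] = p[N]
        return dx, dp
    end

    return ys[1], pullback
end

# Usage: p(x) = 1 + 2x + 3x^2 at x = 2
y, back = evalpoly_with_pullback(2.0, [1.0, 2.0, 3.0])  # y == 17.0
dx, dp = back(1.0)                                      # dx == 14.0, dp == [1.0, 2.0, 4.0]
```

The reverse sweep does O(N) work and allocates two length-N vectors in total (`ys` and `dp`). This is the kind of code one would hope a source-to-source AD could produce on its own if it could see that each `p[i]` is read exactly once.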
For reference, JAX gets performance comparable to the custom rule without needing one (see jax-ml/jax#3047 (comment)).