
How to speed up pullbacks when iterating over arrays? #644

@sethaxen

Consider the following example from Base Julia:

```julia
evalpoly(x, p::AbstractVector) = _evalpoly(x, p)

function _evalpoly(x, p)
    N = length(p)
    ex = p[end]
    for i in N-1:-1:1
        ex = muladd(x, ex, p[i])
    end
    ex
end
```
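Writing the coefficients as p_1 through p_N, the loop evaluates the polynomial below by Horner's method, and its gradient has a closed form that can be evaluated in O(N) time:

```math
y = \sum_{i=1}^{N} p_i x^{i-1}, \qquad
\frac{\partial y}{\partial p_i} = x^{i-1}, \qquad
\frac{\partial y}{\partial x} = \sum_{i=2}^{N} (i-1)\, p_i x^{i-2}.
```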

This function is simple and reads each element of `p` exactly once, so ideally its gradient would be fast. However:

```julia
julia> using Zygote, BenchmarkTools

julia> x, p = rand(), randn(10000);

julia> @btime evalpoly(x, p);
  17.323 μs (1 allocation: 16 bytes)

julia> @btime Zygote.gradient(evalpoly, x, p);
  679.433 ms (490088 allocations: 1.51 GiB)
```

To work around this, I'm adding a custom rule for `evalpoly` to ChainRules.jl (JuliaDiff/ChainRules.jl#190), which will speed things up dramatically once #366 is merged:

```julia
julia> @btime Zygote.gradient(evalpoly, x, p); # over 10,000x faster than before!
  43.338 μs (21 allocations: 156.81 KiB)
```
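For comparison, here is a minimal sketch of the kind of reverse pass a hand-written rule can do: store the Horner accumulators on the way forward, then walk the loop backwards once, writing into a single preallocated gradient vector. The name `evalpoly_pullback` is hypothetical, this is plain Julia rather than the ChainRules.jl `rrule` API, it assumes 1-based vectors, and it is not the rule from JuliaDiff/ChainRules.jl#190.

```julia
# Hypothetical helper (plain Julia, not the ChainRules.jl API): returns the
# value of the polynomial and a pullback that reverses the Horner loop by hand.
# Assumes `p` is a non-empty, 1-based vector.
function evalpoly_pullback(x::Real, p::AbstractVector{<:Real})
    N = length(p)
    T = promote_type(typeof(x), eltype(p))
    # exs[i] is the accumulator after folding in p[i]; exs[N] == p[N],
    # and exs[1] is the final result.
    exs = similar(p, T, N)
    exs[N] = p[end]
    for i in N-1:-1:1
        exs[i] = muladd(x, exs[i+1], p[i])
    end
    y = exs[1]

    function back(ȳ)
        p̄ = zeros(promote_type(typeof(ȳ), T), N)  # the single gradient buffer
        x̄ = zero(eltype(p̄))
        ēx = ȳ  # adjoint of the current accumulator, starting from exs[1]
        for i in 1:N-1
            # Forward step was: exs[i] = x * exs[i+1] + p[i]
            p̄[i] = ēx               # ∂exs[i]/∂p[i] == 1
            x̄ += ēx * exs[i+1]      # ∂exs[i]/∂x == exs[i+1]
            ēx *= x                 # ∂exs[i]/∂exs[i+1] == x
        end
        p̄[N] = ēx                   # exs[N] == p[N]
        return x̄, p̄
    end

    return y, back
end
```

Both passes are O(N), and the only length-N allocations are the stored accumulators and the gradient buffer. For `N = 10000` Float64 coefficients that is about 2 × 78 KiB, which is consistent with the 156.81 KiB reported for the custom rule above.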

But this is a band-aid. How can we improve this in the general case? Is the main problem the adjoint for `getindex` (e.g. #365)? Even with a much better `getindex` adjoint, Zygote won't know that each element of `p` is used only once, so it can't do something fast like allocating a single gradient vector and filling it in place.
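A minimal sketch of the cost difference, assuming a `getindex` pullback that materializes a dense zero array for every `p[i]` read and accumulates the results with broadcasted addition. The function names are hypothetical, this is plain Julia rather than Zygote's source, and it computes only ∂y/∂p_i = x^(i-1). Under that assumption, 10,000 reads of a length-10,000 Float64 vector allocate about 0.75 GiB per array, so two arrays per read come to roughly 1.5 GiB, which is consistent with the 1.51 GiB reported above.

```julia
# Hypothetical illustration, not Zygote's source.
# Each read of p[i] yields a full-length array with one nonzero entry, and
# those arrays are summed: O(N) work and allocation per read, O(N^2) in total.
function grad_p_dense_onehot(x, p)
    N = length(p)
    T = float(promote_type(typeof(x), eltype(p)))
    g = zeros(T, N)
    w = one(T)                 # w == x^(i-1), built up incrementally
    for i in 1:N
        onehot = zeros(T, N)   # dense array for a single p[i] read
        onehot[i] = w
        g = g .+ onehot        # accumulation allocates yet another array
        w *= x
    end
    return g
end

# If each element is known to be read exactly once, one buffer suffices:
# O(N) work and a single length-N allocation.
function grad_p_single_buffer(x, p)
    N = length(p)
    T = float(promote_type(typeof(x), eltype(p)))
    g = zeros(T, N)
    w = one(T)
    for i in 1:N
        g[i] += w              # in-place write, no temporaries
        w *= x
    end
    return g
end
```

Both functions return the same gradient; timing them with `@btime` for growing `length(p)` should show the quadratic versus linear scaling.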

For reference, JAX achieves performance comparable to the custom rule without needing one (see this comment on jax-ml/jax#3047).
