
Backprop through time #648

@MikeInnes


Continuing our series "cool things we can't have yet", and inspired by this comment, I was thinking about how we'll expose backpropagation through time (BPTT). Currently, given a forward pass like this:

loss = 0
for word in seq
  loss += model(word)
end
loss

If we don't want to backprop over the whole sequence at once (gradient outside the loop), or over only a single step at a time (gradient inside the loop), then we need to split the loop as follows:

for chunk in Iterators.partition(seq, n)   # n is the truncation length
  θ̄ = gradient() do
    loss = 0
    for word in chunk
      loss += model(word)
    end
    loss
  end
end
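
For concreteness, here's roughly what that chunked loop looks like written against Flux's implicit-parameter gradient today. The model, data, loss, and chunk size are all placeholders for illustration; the point is that the truncation length ends up baked into the shape of the forward loop:

using Flux

model = RNN(10, 10)                          # placeholder recurrent model
seq   = [rand(Float32, 10) for _ in 1:100]   # placeholder input sequence
ps    = Flux.params(model)

for chunk in Iterators.partition(seq, 5)     # backprop over 5 steps at a time
  gs = gradient(ps) do
    loss = 0f0
    for word in chunk
      loss += sum(model(word))               # placeholder per-step loss
    end
    loss
  end
  # apply an optimiser update with `gs` here; gradients never flow from this
  # chunk back into earlier ones, and the truncation length is fixed at 5
end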

An alternative to this is to just expose primitives that let us fiddle with time steps directly. Consider:

record = Capacitor(5)   # proposed: remember the backpropagators from the last 5 steps
for x in seq
  θ̄ = gradient() do
    record() do
      model(x)
    end
  end
end

Alright, bear with me here. This is written as if we were backpropagating across only a single time step at a time, but with the model evaluation wrapped in record. The idea is that record logs the backpropagators from the last 5 calls to the closure it is passed, and then chains them together for the backwards pass, which means we actually backpropagate through the previous n iterations of the loop -- i.e. backpropagation through time.
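
To make that chaining concrete, here's a hand-rolled sketch of the mechanism record is described as doing. Nothing here is the proposed API: the toy cell, the explicit parameters, and the manual accumulation are all assumptions for illustration. The point is just that keeping each step's pullback and feeding the hidden-state gradient from step t+1 into the pullback for step t is exactly BPTT:

using Zygote

cell(h, x, W, U) = tanh.(W * x .+ U * h)     # toy recurrent cell

function bptt(xs, h, W, U)
  tape = Any[]                               # the "capacitor": logged pullbacks
  for x in xs
    h, back = Zygote.pullback((h, W, U) -> cell(h, x, W, U), h, W, U)
    push!(tape, back)
  end
  h̄ = 2h                                     # seed: gradient of a toy loss sum(abs2, h)
  W̄, Ū = zero(W), zero(U)
  for back in reverse(tape)                  # chain the pullbacks, newest first
    h̄, w̄, ū = back(h̄)                        # gradient w.r.t. previous h and the params
    W̄ .+= w̄; Ū .+= ū
  end
  return W̄, Ū
end

W̄, Ū = bptt([randn(3) for _ in 1:5], zeros(3), randn(3, 3), randn(3, 3))

In the proposal, the recorder would do this logging and chaining behind the scenes, with parameter gradients flowing through the usual implicit-params machinery rather than being threaded explicitly as above.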

What's cool about this is that it makes BPTT completely orthogonal to the structure of the forward pass. The recorder can equally well be set up to backprop through the last n steps on each iteration (sliding-window BPTT) or only on every nth iteration (normal truncated BPTT), or anything in between, and this can be configured differently for different parts of the model. It also isn't specific to any particular RNN implementation; e.g. this will work even though we have to backprop through h across loop iterations:

record = Capacitor(5)
h = ...
for word in seq
  θ̄ = gradient() do
    record() do
      y, h = model(h, word)   # hidden state threaded across iterations
      loss(word, y)
    end
  end
end
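
And for the configurability point above, one could imagine the truncation policy living entirely in the recorder's constructor rather than in the shape of the loop. To be clear, these constructors are purely hypothetical illustrations, not proposed API:

enc_record = Capacitor(5)              # hypothetical: backprop the last 5 steps on every iteration (sliding window)
dec_record = Capacitor(5, every = 5)   # hypothetical: accumulate 5 steps, backprop only every 5th iteration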

The main question is whether this is actually going to be intuitive for people (who aren't travelling at 88mph). If it looks weird right now, I think that's partly because we're not used to using gradient this way, so getting used to that will make the extra feature easier to reason about. At least for sliding windows, I think it's strictly better than flow-based alternatives.
