Continuing our series "cool things we can't have yet", and inspired by this comment, I was thinking about how we'll expose BPTT (backpropagation through time). Currently, given a forward pass like this:
loss = 0
for word in seq
  loss += model(word)
end
loss
If we don't want to backprop over the whole sequence at once (gradient outside the loop), or over only a single step at a time (gradient inside the loop), then we need to split the loop as follows:
for chunk in Iterators.partition(seq, n)
  θ̄ = gradient() do
    loss = 0
    for word in chunk
      loss += model(word)
    end
    loss
  end
end
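For intuition, the chunked pattern above can be sketched in plain Python on a toy scalar RNN (h = a*h + x, with per-step loss h). The function name and the toy model are invented for illustration, not anything from Flux or Zygote; the point is just that each chunk backprops internally while the hidden state entering the chunk is treated as a constant, which is exactly the truncation:

```python
# Sketch of chunked (truncated) BPTT on a toy scalar RNN:
#   h_t = a * h_{t-1} + x_t,  per-step loss = h_t
# All names here are illustrative, not from any library.

def chunked_bptt(seq, a, chunk_size):
    """Return d(total loss)/da, truncating gradients at chunk boundaries."""
    h = 0.0
    grad_a = 0.0
    for start in range(0, len(seq), chunk_size):
        chunk = seq[start:start + chunk_size]
        # Forward through the chunk, recording intermediate states.
        hs = [h]                 # hs[t] is the hidden state entering step t
        for x in chunk:
            h = a * h + x
            hs.append(h)
        # Backward through this chunk only (loss = sum of h_t in the chunk).
        dh = 0.0                 # cotangent flowing back through h
        for t in range(len(chunk) - 1, -1, -1):
            dh += 1.0            # d(loss)/d(h_t) from this step's loss term
            grad_a += dh * hs[t] # h_t = a*h_{t-1} + x_t  =>  d(h_t)/da = h_{t-1}
            dh *= a              # chain the cotangent into the previous h
        # dh is dropped here: gradients do not cross the chunk boundary.
    return grad_a
```

With chunk_size equal to the sequence length this recovers full BPTT; with chunk_size 1 it degenerates to per-step gradients, matching the two extremes described above.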
An alternative is to expose primitives that let us fiddle with time steps directly. Consider:
record = Capacitor(5)
for x in seq
  θ̄ = gradient() do
    record() do
      model(x)
    end
  end
end
Alright, bear with me here. This is written as if we were backprop-ing across only a single time step at a time, but with the model evaluation wrapped in record. The idea is that record will log the 5 previous backpropagators for the closure it is passed, and then chain these together for the backwards pass, which means we can actually backpropagate through the n previous iterations of the loop -- i.e. backpropagation through time.
What's cool about this is that it makes BPTT completely orthogonal to the structure of the forward pass. The recorder can equally well be set up to backprop the last n steps each iteration (sliding-window BPTT), or only every nth iteration (ordinary truncated BPTT), or anything in between, and this can be set up differently for different parts of the model. It also isn't specific to any particular RNN implementation; e.g. this will work even though we have to backprop through h over loop iterations:
record = Capacitor(5)
h = ...
for word in seq
  θ̄ = gradient() do
    record() do
      y, h = model(word)
      loss(word, y)
    end
  end
end
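To make the mechanics concrete, here is a rough Python sketch of what a Capacitor could look like under the hood: a ring buffer of per-step backpropagators that the backward pass chains newest-to-oldest. The names and the toy RNN (h = a*h + x, per-step loss = h) are invented for illustration and are not a proposed API:

```python
from collections import deque

class Capacitor:
    def __init__(self, n):
        # Ring buffer: keeps only the n most recent backpropagators.
        self.steps = deque(maxlen=n)

    def record(self, back):
        self.steps.append(back)

    def gradient(self):
        # Chain the stored backpropagators newest-to-oldest: BPTT over
        # (at most) the last n recorded steps.
        grad_a, dh = 0.0, 0.0
        for back in reversed(self.steps):
            grad_a, dh = back(grad_a, dh)
        return grad_a

def run(seq, a, depth):
    # Toy RNN: h_t = a * h_{t-1} + x_t, per-step loss = h_t.
    cap = Capacitor(depth)
    h = 0.0
    for x in seq:
        h_prev = h
        h = a * h_prev + x
        # Backpropagator for this step: takes the (grad_a, dh) flowing in
        # from later steps and pushes them one step back in time.
        def back(grad_a, dh, h_prev=h_prev):
            dh += 1.0              # this step's loss term: d(loss)/d(h_t) = 1
            grad_a += dh * h_prev  # d(h_t)/da = h_{t-1}
            return grad_a, dh * a  # chain the cotangent into h_{t-1}
        cap.record(back)
    return cap.gradient()
```

Calling the capacitor's gradient every iteration gives the sliding-window behaviour; calling it only every nth iteration gives ordinary truncated BPTT, and neither choice requires touching the forward loop, which is the orthogonality claimed above.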
The main question is whether this is actually going to be intuitive for people (who aren't travelling at 88mph). If it looks weird right now, I think that's partly because we're not used to using gradient this way, so getting used to that will make the extra feature easier to reason about. At least for sliding windows, I think it's strictly better than flow-based alternatives.