There have been several issues/PRs related to freezing model parameters:
- freeze parameters #1022
- How to keep weights of parts of a model fixed under Flux.train! #1001
- Implement APIs of freeze parameters and freeze layers #1101
- delete! for Params Zygote.jl#505
- Per-leaf freezing Optimisers.jl#49
Right now, the recommendation made in the documentation is to manually specify which parameters should not be trained using some combination of `Flux.params` and `Zygote.delete!`.
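For concreteness, a minimal sketch of that approach (assuming the `delete!(::Params, p)` method from Zygote.jl#505 linked above is available) might look like:

```julia
using Flux

m = Chain(Dense(2, 3, tanh), Dense(3, 4, tanh), Dense(4, 2))

# Either collect only the parameters you actually want to train...
ps = Flux.params(m[1], m[3])

# ...or collect everything and manually remove the frozen layer's parameters.
ps = Flux.params(m)
for p in Flux.params(m[2])
    delete!(ps, p)
end

gs = gradient(() -> sum(m(rand(Float32, 2))), ps) # gradients are tracked only for `ps`
```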
While this works, it is somewhat inflexible in several respects:
- Training routines must be aware of the model architecture in order to select which parameters to freeze
- Specifying that a layer is frozen is often much more convenient to do at model construction time, particularly if the frozen layer is nested deeply inside the model
- It is not clear how the current approach would fit into the functional-style approach coming in v0.13, since `Params` would no longer be used at all (would one need to e.g. `fmap` over a model and somehow mark specific layers as frozen before passing to `gradient`?); a rough sketch of this question is given below.
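To make the last point concrete, here is a rough sketch (with a placeholder model and loss) of the explicit-gradient style and where the freezing question arises:

```julia
using Flux

x = rand(Float32, 2)
m = Chain(Dense(2, 3, tanh), Dense(3, 2))

# Explicit style: there is no `Params` object to delete entries from.
# The gradient is a nested (Named)Tuple mirroring the model structure.
g = gradient(m -> sum(m(x)), m)[1]

# Where would "layer 1 is frozen" be recorded? Presumably one would have to
# `fmap`/walk the model (or the returned gradient) and drop the frozen branches
# by hand before handing things to an optimiser.
```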
For these reasons, I often find myself defining a `Frozen` layer (similar to #1001) which looks something like this:

```julia
using Flux
using Flux: @adjoint

struct Frozen{F}
    f::F
end
Flux.@functor Frozen # need functor for e.g. `fmap`
Flux.trainable(::Frozen) = NamedTuple() # no trainable parameters

# Something like `whitebox_apply` is required to explicitly treat `f` as a "white box":
# propagate gradients through `f`, but treat `f` itself as a constant functor
(l::Frozen)(xs...) = whitebox_apply(l.f, xs...)

whitebox_apply(f, xs...) = f(xs...)

@adjoint function whitebox_apply(f, xs...)
    y, J = Flux.pullback(f, xs...)
    y, Δ -> (nothing, J(Δ)...)
end
```

A frozen layer `l::Frozen` wraps a functor `f` and has two properties:
- `l(x) = f(x)` is differentiable with respect to `x` (as opposed to e.g. `l(x) = dropgrad(f(x))`, which would treat `f(x)` as constant)
- `f` is treated as a constant functor: gradients of `l(x)` with respect to parameters internal to `f` return zero
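As a usage sketch (hypothetical names, assuming the `Frozen` wrapper above), freezing at construction time makes the training code itself completely agnostic of which layers are frozen:

```julia
# `backbone` stands in for a pretrained feature extractor whose weights should
# stay fixed; only the newly attached head is trained.
backbone = Chain(Dense(784, 128, relu), Dense(128, 64, relu))
model = Chain(Frozen(backbone), Dense(64, 10))

ps = Flux.params(model)  # only the head's parameters are collected
loss(x, y) = Flux.logitcrossentropy(model(x), y)
data = [(rand(Float32, 784, 32), Flux.onehotbatch(rand(0:9, 32), 0:9))]

Flux.train!(loss, ps, data, ADAM())  # the backbone is never updated
```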
Below is some test code to illustrate how this layer should behave:
Examples/tests
```julia
x = rand(Float32, 2)
l1 = Dense(2, 3, tanh)
l2 = Dense(3, 4, tanh)
l3 = Dense(4, 2, identity)
m0 = Chain(l1, l2, l3)
m1 = Chain(l1, Frozen(l2), l3) # identical to `m0` but with the middle layer frozen
p0 = Flux.params(m0)
p1 = Flux.params(m1)
pfree = Flux.params(l1, l3)
pfrozen = Flux.params(l2)
# Basics
@assert all(p ∈ p1 for p in pfree) # free params are present
@assert all(p ∉ p1 for p in pfrozen) # frozen params are not
∇p1 = gradient(() -> sum(m1(x)), pfrozen)
@assert all(∇p1[p] === nothing for p in pfrozen) # frozen params have zero gradients, even if passed to `gradient` explicitly
∇p1 = gradient(() -> sum(m1(x)), p1)
@assert all(haskey(∇p1, p) for p in pfree) # free params have gradients
@assert !any(haskey(∇p1, p) for p in pfrozen) # frozen params do not have gradients
∇p0 = gradient(() -> sum(m0(x)), p0)
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params
# This loss is constant as a function of `pfree`: `m0` and `m1` co-vary exactly as `pfree` changes,
# and therefore the difference `m0(x) - m1(x)` is zero with zero gradient w.r.t. `pfree`.
# However, since `m1` is treated as a constant function of `pfrozen` but `m0` is not,
# the gradient of `m0(x) - m1(x)` is nonzero w.r.t. `pfrozen`.
loss = () -> sum(m0(x) - m1(x))
∇p0 = gradient(loss, p0)
@assert all(iszero(∇p0[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(iszero(∇p0[p]) for p in pfrozen) # gradient != 0 for frozen parameters
∇p1 = gradient(loss, p1)
@assert all(iszero(∇p1[p]) for p in pfree) # gradient == 0 for free parameters
@assert !any(haskey(∇p1, p) for p in pfrozen) # gradients not present for frozen parameters
@assert all(∇p0[p] ≈ ∇p1[p] for p in pfree) # gradients are equal for free params
```

If there is interest in including a layer like `Frozen` in Flux, I would be happy to make a PR. Of course, if there is an easy way to do what I'm describing which I have overlooked, please do let me know and I'll close this issue.