Closed
Description
Hi, maybe I'm just too dense in reading the docs, but I couldn't quite figure out a nice way to do this. Let's say I have a model
m = Flux.Chain(Flux.Conv(...), Flux.Conv(...), Flux.Conv(...), ...)
and I want to keep the parameters of the first chain entry fixed under Flux.train!. How would I best go about it? My first thought was to simply write a "wrapper" like
struct Fixed
    m
end
(f::Fixed)(x) = f.m(x)
which is deliberately not declared with Flux.@functor, and then use it in m:
m = Flux.Chain(Fixed(Flux.Conv(...)), Flux.Conv(...), Flux.Conv(...), ...)
But sadly that breaks moving the model to the GPU, presumably because the wrapped layer is no longer visible to Flux's functor machinery. I tried reading the Flux.@functor code, but it is a little over my head. I guess another approach would be to pick apart the return value of Flux.params(m) and take out everything I don't want updated, but I'm sure there's a nicer way that composes more intuitively.
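For reference, the parameter-filtering idea might look roughly like the sketch below. It assumes a Flux version that still provides the implicit-parameter API (Flux.params and the four-argument Flux.train!). The layer sizes, the loss, and the data are hypothetical stand-ins, since the real layers are elided above.

```julia
using Flux

# Hypothetical stand-in model; the actual Conv arguments are not given in the post.
m = Flux.Chain(
    Flux.Conv((3, 3), 1 => 4, Flux.relu),
    Flux.Conv((3, 3), 4 => 4, Flux.relu),
    Flux.Conv((3, 3), 4 => 1),
)

# Collect parameters from every layer except the first.
# Indexing a Chain with a range returns a Chain of the selected layers.
ps = Flux.params(m[2:end])

# Hypothetical loss and a single made-up batch (8x8 input shrinks to 2x2 after three 3x3 convs).
loss(x, y) = sum(abs2, m(x) .- y)
x = rand(Float32, 8, 8, 1, 2)
y = rand(Float32, 2, 2, 1, 2)

w_before = copy(m[1].weight)
Flux.train!(loss, ps, [(x, y)], Flux.Descent(0.1))

# The first layer is not in ps, so the optimiser never touches it.
@assert m[1].weight == w_before
```

Because no wrapper type is involved, the whole model can still be moved to the GPU in the usual way before collecting ps. The downside is the one already mentioned: the choice of frozen layers lives in the training code rather than in the model definition.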