How to keep weights of parts of a model fixed under Flux.train! #1001

@fps

Hi, maybe I'm just too dense in reading the docs, but I couldn't quite figure out a nice way to do this. Let's say I have a model:

m = Flux.Chain(Flux.Conv(...), Flux.Conv(...), Flux.Conv(...), ...)

and I want to keep the parameters of the first chain entry fixed under Flux.train!. How would I best go about it? My first thought was to simply use a "wrapper" like

struct Fixed
  m
end
(f::Fixed)(x) = f.m(x)

which is explicitly not marked with Flux.@functor, and then use it in m:

m = Flux.Chain(Fixed(Flux.Conv(...)), Flux.Conv(...), Flux.Conv(...), ...)

But sadly that breaks moving the model to the GPU. I tried reading the Flux.@functor code, but it is a little bit over my head. I guess another approach would be to pick apart the return value of Flux.params(m) and take out everything I don't want updated, but I'm sure there's a nicer way that composes more intuitively.
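
For reference, the "pick apart Flux.params" idea might look roughly like the sketch below. It assumes a Flux version that still supports implicit parameters (Flux.params and the four-argument Flux.train!(loss, ps, data, opt)) as well as the `in => out` layer constructor syntax. The small Dense layers, the random data, and the mse loss are hypothetical stand-ins for the elided Conv(...) layers and are not taken from my real model.

```julia
using Flux

# Hypothetical stand-in model; the real one uses Conv layers that are elided above.
m = Flux.Chain(
    Flux.Dense(4 => 8, Flux.relu),
    Flux.Dense(8 => 8, Flux.relu),
    Flux.Dense(8 => 2),
)

# Collect parameters from every layer except the first one.
# Indexing a Chain with a range returns a Chain of the selected layers.
ps = Flux.params(m[2:end])

# Snapshot the values so the effect of training can be compared afterwards.
first_before = [copy(p) for p in Flux.params(m[1])]
last_before  = [copy(p) for p in Flux.params(m[end])]

# Hypothetical toy data and loss, only there to make the sketch runnable.
x = rand(Float32, 4, 16)
y = rand(Float32, 2, 16)
loss(x, y) = Flux.mse(m(x), y)
data = [(x, y)]
opt = Flux.Descent(0.1)

# Only the arrays contained in ps are updated by train!.
Flux.train!(loss, ps, data, opt)

# The first layer should be untouched, the last layer should have moved.
first_unchanged = all(a == b for (a, b) in zip(first_before, Flux.params(m[1])))
last_changed    = any(a != b for (a, b) in zip(last_before, Flux.params(m[end])))
```

Since the model itself stays an ordinary Chain, moving it to the GPU with gpu(m) is unaffected. In that case ps would presumably need to be collected after the move, so that it refers to the arrays that are actually being trained.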
