
Compatibility with ProtoStruct.jl, and LayerFactory ideas for custom layers #2107

@MilesCranmer


There's a really nice package, ProtoStructs.jl, that lets you create structs which can be redefined while Revise.jl is tracking your code. I think this would be extremely useful for developing custom models in Flux.jl, since otherwise I need to restart Julia every time I want to add a new field to one of my layers.

Essentially, the way it works is to transform this:

@proto struct MyLayer
    chain1::Chain
    chain2::Chain
end

into this (regardless of which properties are declared):

struct MyLayer{NT<:NamedTuple}
    properties::NT
end

and, inside the macro, set up constructors (plus a getproperty method that forwards field access into properties) based on the currently defined properties.
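To make the transformation concrete, here is a hand-written approximation of the idea without the macro. It is only a sketch under stated assumptions: the name MyLayerSketch, the plain vectors standing in for Chain, and the exact set of methods are illustrative and are not ProtoStructs' literal expansion. What it shows is that the type has a single real field, properties, while getproperty forwards every access to the wrapped NamedTuple (the stack trace further down shows ProtoStructs overriding getproperty for ResidualDense in the same spirit).

```julia
# Hand-written approximation of the @proto idea (NOT the macro's real output).
# The only concrete field is `properties`; the declared fields live inside it.
struct MyLayerSketch{NT<:NamedTuple}
    properties::NT
end

# Hypothetical positional constructor for the "declared" fields chain1, chain2.
MyLayerSketch(chain1, chain2) = MyLayerSketch((chain1 = chain1, chain2 = chain2))

# Forward `x.name` to the wrapped NamedTuple. `getfield` is needed to reach the
# real field, because `x.properties` would itself be forwarded.
Base.getproperty(x::MyLayerSketch, s::Symbol) = getproperty(getfield(x, :properties), s)
Base.propertynames(x::MyLayerSketch) = propertynames(getfield(x, :properties))

# Plain vectors stand in for the Chain fields so this runs without Flux.
layer = MyLayerSketch([1.0, 2.0], [3.0])

println(fieldnames(typeof(layer)))  # (:properties,)      the type's real fields
println(propertynames(layer))       # (:chain1, :chain2)  the declared fields
println(layer.chain1)               # [1.0, 2.0]
```

Because fieldnames only sees properties while propertynames sees the declared fields, any code that walks a struct by its type's field names will end up asking the object for a field called properties.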

However, right now it doesn't work with Flux.jl. When I try to collect the parameters of a model with params, I get the error "type NamedTuple has no field properties". Here is an MWE:

# Desired API for FluxML
using Flux
using Flux: params, @functor
using ProtoStructs

@proto struct ResidualDense
    w1::Dense
    w2::Dense
    act::Function
end

"""Residual layer."""
function (r::ResidualDense)(x)
    dx = r.w2(r.act(r.w1(x)))
    return r.act(dx + x)
end

@functor ResidualDense

function ResidualDense(in, out; hidden=128, act=relu)
    ResidualDense(Dense(in, hidden), Dense(hidden, out), act)
end

# Chain of dense and residual layers:
mlp = Chain(
    Dense(5 => 128),
    ResidualDense(128, 128),
    ResidualDense(128, 128),
    Dense(128 => 1),
);

p = params(mlp);  # Fails

and here is the error:

ERROR: type NamedTuple has no field properties
Stacktrace:
  [1] getproperty
    @ ./Base.jl:38 [inlined]
  [2] getproperty(o::ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, s::Symbol)
    @ Main ~/.julia/packages/ProtoStructs/4sIVY/src/ProtoStruct.jl:134
  [3] functor(#unused#::Type{ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}}, x::ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}})
    @ Main ~/.julia/packages/Functors/V2McK/src/functor.jl:19
  [4] functor(x::ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}})
    @ Functors ~/.julia/packages/Functors/V2McK/src/functor.jl:3
  [5] trainable(x::ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}})
    @ Optimisers ~/.julia/packages/Optimisers/GKFy2/src/interface.jl:153
  [6] params!
    @ ~/.julia/packages/Flux/nJ0IB/src/functor.jl:46 [inlined]
  [7] params!(p::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, x::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, seen::Base.IdSet{Any}) (repeats 3 times)
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/functor.jl:47
  [8] params!
    @ ~/.julia/packages/Flux/nJ0IB/src/functor.jl:40 [inlined]
  [9] params(m::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, ResidualDense{NamedTuple{(:w1, :w2, :act), Tuple{Dense, Dense, Function}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/functor.jl:87
 [10] top-level scope
    @ ~/desired_julia_api/nice_api.jl:37
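Reading frames [2] and [3] of the trace, the functor method generated by @functor ResidualDense appears to collect children using the type's field names, which for a @proto struct is just properties, and then calls getproperty(x, :properties). ProtoStructs forwards that call to the wrapped NamedTuple, which has no such field. The self-contained sketch below reproduces the same chain of calls without Flux or ProtoStructs; ProtoLike and children_by_fieldnames are hypothetical stand-ins for the @proto type and the generated functor, not real library code.

```julia
# Hypothetical stand-in for a @proto type: one real field, forwarded properties.
struct ProtoLike{NT<:NamedTuple}
    properties::NT
end
Base.getproperty(x::ProtoLike, s::Symbol) = getproperty(getfield(x, :properties), s)
Base.propertynames(x::ProtoLike) = propertynames(getfield(x, :properties))

# Hypothetical stand-in for the generated functor: it collects children by the
# *type's* field names, i.e. (:properties,) for a ProtoLike.
children_by_fieldnames(x) =
    NamedTuple{fieldnames(typeof(x))}(map(f -> getproperty(x, f), fieldnames(typeof(x))))

x = ProtoLike((w1 = [1.0], w2 = [2.0], act = identity))

println(fieldnames(typeof(x)))  # (:properties,)
println(propertynames(x))       # (:w1, :w2, :act)

# getproperty(x, :properties) is forwarded to the NamedTuple, which throws,
# mirroring "type NamedTuple has no field properties" in the trace above.
err = try
    children_by_fieldnames(x)
    nothing
catch e
    e
end
println(sprint(showerror, err))
```

The exact exception type and message wording depend on the Julia version, so only the fact that the lookup fails should be relied on.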

How hard would it be to make this compatible? I think it would be extremely useful to be able to quickly revise model definitions!
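One possible direction, sketched purely as an assumption and not as an existing Flux or ProtoStructs API: instead of walking the type's field names, take the children directly from the wrapped NamedTuple and rebuild the object by wrapping a new one. In Functors.jl the real hook is a method of Functors.functor returning a (children, reconstruct) pair; the code below only mimics that contract with a hypothetical children_and_rebuild function so it runs without any packages.

```julia
# Hypothetical stand-in for a @proto type (same shape as ResidualDense's wrapper).
struct ProtoLike{NT<:NamedTuple}
    properties::NT
end
Base.getproperty(x::ProtoLike, s::Symbol) = getproperty(getfield(x, :properties), s)
Base.propertynames(x::ProtoLike) = propertynames(getfield(x, :properties))

# Hypothetical analogue of a custom functor method: return the declared fields
# as children, plus a closure that rebuilds a ProtoLike from new children.
function children_and_rebuild(x::ProtoLike)
    children = getfield(x, :properties)  # (w1 = ..., w2 = ..., act = ...)
    rebuild(new_children::NamedTuple) = ProtoLike(new_children)
    return children, rebuild
end

x = ProtoLike((w1 = [1.0, 2.0], w2 = [3.0], act = identity))
children, rebuild = children_and_rebuild(x)

# Transform every array child (here: double it) and rebuild, roughly what an
# fmap-style traversal does with the children of each node.
new_children = map(c -> c isa AbstractArray ? 2 .* c : c, children)
y = rebuild(new_children)

println(propertynames(y))  # (:w1, :w2, :act)
println(y.w1)              # [2.0, 4.0]
```

The design choice here is to use getfield for the one real field, so the traversal never goes through the forwarded getproperty and never asks for a field named properties.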


(Sorry for the spam today, by the way)
