    @layer Dense
    @layer :expand Chain
    @layer BatchNorm trainable=(β,γ)
-    @layer Struct functor=(α,β) trainable=(β,)
+    @layer Struct children=(α,β) trainable=(β,)

This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
When you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.

Some "keywords" allow control of the recursion:
* If some fields look like parameters but should not be trained,
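For instance, a minimal sketch of the intended workflow (`SkipAdd` is a hypothetical layer invented for illustration; the rest is ordinary Flux 0.14 usage):

```julia
using Flux

struct SkipAdd          # a custom layer wrapping another layer
    inner
end
(m::SkipAdd)(x) = x .+ m.inner(x)

Flux.@layer SkipAdd     # let Flux recurse into `inner`

model = SkipAdd(Dense(3 => 3))
model32 = Flux.f32(model)               # precision changes now reach the Dense inside
opt_state = Flux.setup(Adam(), model)   # and optimisers can see its parameters
```

Without the macro, `Flux.setup`, `gpu`, `f32` and friends would treat `SkipAdd` as a leaf and find no parameters inside it.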
-  then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
-* We can likewise add restructions to `Functors.functor`, but not yet written.
-* In fact you can provide an arbitrary keyword with this syntax, and it will
-  overload this function alla `trainable`... that might be a terrible idea.
+  then `trainable` lets you specify which fields to include, ignoring the rest
+  (see the sketch below this list).
+* We can likewise add restrictions to Functors' `children`,
+  but this is not yet written (as it is seldom a good idea).
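For instance, a sketch of the `trainable` keyword (`PartialDense` is hypothetical, invented for illustration only):

```julia
using Flux

struct PartialDense     # a layer whose bias should stay fixed
    weight
    bias
end
(m::PartialDense)(x) = m.weight * x .+ m.bias

Flux.@layer PartialDense trainable=(weight,)
```

Here `Optimisers.trainable` returns only `weight`, so `train!` leaves `bias` untouched, while `gpu`, `f32`, etc. still recurse into both fields.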

It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
-* To disable all `show` overloads, maybe we want a `:ignore` option too.
+* To disable all `show` overloads, there is an `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)` yourself, since the macro does not touch this; see the sketch below.)
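For instance, a compact 2-arg method for the hypothetical `PartialDense` layer above could look like this (a sketch only, not part of this diff):

```julia
Base.show(io::IO, m::PartialDense) =
    print(io, "PartialDense(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
```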
+
+Note that re-running the macro with different options does not overwrite all methods; you will need to restart Julia.
+
+# Example
+```jldoctest
+julia> struct Trio; a; b; c end
+
+julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense([3.3;;], false), Dropout(0.4))
+Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
+
+julia> Flux.destructure(tri)  # parameters not visible to Flux
+(Bool[], Restructure(Trio, ..., 0))
+
+julia> Flux.@layer :expand Trio
+
+julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
+([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
+
+julia> tri
+Trio(
+  Dense(2 => 1, tanh),                  # 3 parameters
+  Dense(1 => 1; bias=false),            # 1 parameters
+  Dropout(0.4),
+)                   # Total: 3 arrays, 4 parameters, 224 bytes.
+```
+
"""
macro layer(exs...)
  out = quote end
@@ -40,10 +65,10 @@ macro layer(exs...)
  end

  # This function exists only for depwarns when you use @functor directly
-  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))

-  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
-  if isnothing(i)
+  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
+  if isnothing(i)  # then default like @functor Layer
    push!(out.args, _macro_functor(esc(type)))
  else
    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
@@ -52,13 +77,14 @@ macro layer(exs...)
    j == i && continue
    ex = rest[j]
    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
-    if ex.args[1] == :trainable
-      push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+
+    name = if ex.args[1] == :trainable
+      :(Optimisers.trainable)
    else
-      error()
-      # @warn "defining a method for $(ex.args[1]) in your scope" # ??
-      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+      @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+      esc(ex.args[1])
    end
+    push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
  end

  out
@@ -72,17 +98,16 @@ function _check_new_macro(x::T) where T
end
_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
-_check_new_macro(::Transpose) = nothing
-_check_new_macro(::Adjoint) = nothing
+_check_new_macro(::AbstractArray) = nothing
_check_new_macro(::Ref) = nothing

# @layer's code for Functors & Adapt
# Unlike @functor, _default_functor doesn't need to eval anything

function _macro_functor(type)
  quote
-    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
-    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+    Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
+    Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
  end
end

@@ -94,12 +119,13 @@ function _default_functor(::Type{T}, x) where {T}
  if @generated
    F = fieldnames(T)
    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
-    C = Base.typename(T).name  # constructor
+    C = Base.typename(T).wrapper  # constructor
    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
    :((NamedTuple{$F}(($(args...),)), $recon))
  else
    # Getting this parameterless type takes about 2μs, every time:
-    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+    spl = VERSION > v"1.9-" ? Splat : Base.splat
+    namedtuple(x), spl(Base.typename(T).wrapper)
  end
end

@@ -117,61 +143,12 @@ function _macro_trainable(type, fun, fields)
  quoted = map(QuoteNode, symbols)
  gets = [:(getfield(x, $f)) for f in quoted]
  quote
-    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
-    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+    $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
+    # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
  end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
_noquotenode(ex) = error("expected a symbol, got $ex")
-
-
-
-
-
-
-# @big_show Chain
-# @big_show Parallel
-# @big_show SkipConnection
-# @big_show Recur
-# @big_show Maxout
-
-
-
-
-"""
-    @big_show MyContainer
-
-This macro lets you opt-in to Flux's fancy printing.
-
-When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
-and the printing routine will recursively unfold its children.
-This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
-Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
-need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
-# Example
-```jldoctest
-julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
-julia> Flux.@functor Trio
-
-julia> Flux.@big_show Trio
-
-julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
-Trio(
-  Dense(10 => 5, tanh),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 492 bytes.
-```
-
-Note that there is no automatic method for 2-arg `show`, and thus
-something like `(tri, tri)` will print all the type parameters.
-
-However, `Chain(tri, tri)` will always use Flux's recursive printing,
-even without using this macro: `Chain` is the entry point.
-"""