@@ -12,3 +12,309 @@ function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, runnin
    end
    y, batchnorm_pullback
end

"""
    norm_stats(x, dims)

Calculates sample mean and (uncorrected) variance of `x` along `dims`.

- `dims=(1,...,N-2,N)` for BatchNorm
- `dims=(1,...,N-2)` for InstanceNorm and GroupNorm
- `dims=(1,...,S)` where S < N for LayerNorm

This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
because it can share some computation across both.
Implementors may want to overload this function to use custom kernels and more.
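
# Examples

An illustrative sketch of the expected shapes (the input is random, so only shapes are checked):

```jldoctest
julia> x = rand(Float32, 3, 3, 4, 2);

julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4)); # BatchNorm-style reduction dims

julia> size(μ) == size(σ²) == (1, 1, 4, 1)
true
```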
"""
function norm_stats(x, dims)
    μ = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    return μ, σ²
end

function rrule(::typeof(norm_stats), x, dims)
    μ, mean_pullback = rrule(mean, x; dims)
    σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
    function norm_stats_pullback(dargs)
        dμ, dσ² = unthunk(dargs)
        # Accumulate the input cotangents from both pullbacks, reusing a buffer when possible
        dx = ChainRulesCore.add!!(var_pullback(dσ²)[2], mean_pullback(dμ)[2])
        return (NoTangent(), dx, NoTangent())
    end
    return (μ, σ²), norm_stats_pullback
end

_maybe_reshape(::Nothing, _) = nothing
_maybe_reshape(x, dims) = reshape(x, dims)
_apply_scale_bias(x, ::Nothing, ::Nothing) = x
_apply_scale_bias(x, scale, bias) = x .* scale .+ bias

"""
    norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))

Shared code path for all built-in norm functions.

`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
or extracted from an existing collection such as [`RunningStats`](@ref).
`bias` and `scale` are consistent with cuDNN and Flux.Scale.
We opt for `scale` over `weight` to avoid confusion with dense layers.
If the size of the statistics and affine parameters differ,
use `affine_size` to add padding dimensions as required to match the input.
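
# Examples

A minimal sketch of how [`norm_stats`](@ref) and `norm_helper` compose (illustrative only; the input is random, so only the output mean is checked):

```jldoctest
julia> using Statistics

julia> x = rand(Float32, 3, 4, 2);

julia> μ, σ² = NNlib.norm_stats(x, (1, 3));

julia> y = NNlib.norm_helper(x, μ, σ², nothing, nothing, 1f-5);

julia> isapprox(mean(y; dims = (1, 3)), zeros(1, 4, 1); atol = 1e-5)
true
```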
"""
function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                     bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
    @ignore_derivatives if isnothing(scale) != isnothing(bias)
        error("both scale and bias must be provided or left as nothing")
    end
    scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
    return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
end

"""
    RunningStats(mean, variance, momentum)

Contains running mean and variance estimates for stateful norm functions.
`momentum` controls the strength of the moving average update.

If the parameters are mutable, they will be updated in-place.
Otherwise, they will be replaced wholesale.

See also [`update_running_stats!`](@ref).
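
# Examples

A hypothetical setup for a 4-channel feature dimension (illustrative only):

```jldoctest
julia> stats = NNlib.RunningStats(zeros(Float32, 4), ones(Float32, 4), 0.1f0);

julia> stats.momentum
0.1f0
```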
"""
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
    mean::M
    variance::V
    momentum::MT
end

# Conditionally pulls running stats or calculates them on the fly.
# Part of the reason this is a dedicated function is to have a more type stable pullback.
function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims,
                          use_running_stats::Bool)
    if stats !== nothing && use_running_stats
        # Maintains consistency with mean/var: reshape the stored stats to the reduced
        # shape of x, with a Colon() for the channel dim (e.g. (1, 1, :, 1) for 4D BatchNorm)
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        return reshape(stats.mean, sz), reshape(stats.variance, sz)
    end
    # No running stats exist or are disabled in inference mode
    return norm_stats(x, dims)
end

# Kludge so we can close over a Union inner pullback type
struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}}
    back::B
    projector::P
end
function (pb::MaybeNormStatsPullback)(dargs)
    _, dx = unthunk(pb.back(dargs))
    return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent())
end
function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims,
               use_running_stats::Bool)
    project = ProjectTo(x)
    noop_back(_) = (NoTangent(), NoTangent())
    if stats === nothing || !use_running_stats
        (μ, σ²), back = rrule(norm_stats, x, dims)
    else
        # The default is to track, so this only happens when a layer is frozen
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back
    end
    back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)}
    return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project)
end

"""
    update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ²,
                          reduce_dims) where {N}

Performs a moving average update for layers with tracked statistics.
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).

See also [`RunningStats`](@ref).
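
# Examples

A sketch of one moving-average step, assuming BatchNorm-style reduction dims `(1, 2, 4)`
(illustrative only; with zero-initialized running means, one update yields `momentum .* vec(μ)`):

```jldoctest
julia> stats = NNlib.RunningStats(zeros(4), ones(4), 0.1);

julia> x = rand(3, 3, 4, 2);

julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4));

julia> NNlib.update_running_stats!(stats, x, μ, σ², (1, 2, 4));

julia> stats.mean ≈ 0.1 .* vec(μ)
true
```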
"""
function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
    V = eltype(σ²)
    momentum = stats.momentum
    res_mtm = one(V) - momentum
    m = prod(size(x, i) for i in reduce_dims)
    correction = m / (m - one(V))

    running_mean, running_var = stats.mean, stats.variance
    if ChainRulesCore.is_inplaceable_destination(running_mean)
        stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
    else
        stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
    end
    if ChainRulesCore.is_inplaceable_destination(running_var)
        stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    else
        stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    end
end

# Convenience functions
# We follow roughly the same arg order as torch.nn.functional.*_norm:
# input, unique args for this particular norm type, bias + scale, eps; kwargs...

"""
    layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N, S}

Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation.

Normalizes `x` along the first `S` dimensions.

For an additional learned affine transform, provide an `S`-dimensional `scale` and `bias`.

See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels

julia> y = NNlib.layernorm(xs, Val(3));

julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) &&
       std(y; dims = 1:3) != std(xs; dims = 1:3)
true
```
"""
function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N, S}
    @ignore_derivatives if S > N
        throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
    end
    μ, σ² = norm_stats(x, ntuple(identity, S))
    return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
end

"""
    batchnorm(x::AbstractArray{<:Any, N},
              running_stats::Union{RunningStats, Nothing} = nothing,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
              training::Bool = within_grad()) where {N}

Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`batchnorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
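
# Examples

A sketch of stateless use (no `RunningStats`); since the input is random,
only summary statistics of the output are checked:

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels

julia> y = NNlib.batchnorm(xs);

julia> isapprox(std(y; dims = (1, 2, 4)), ones(1, 1, 3, 1); atol = 0.1) &&
       std(y; dims = (1, 2, 4)) != std(xs; dims = (1, 2, 4))
true
```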
"""
function batchnorm(x::AbstractArray{<:Any, N},
                   running_stats::Union{RunningStats, Nothing} = nothing,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                   training::Bool = within_grad()) where {N}
    reduce_dims = ((1:(N - 2))..., N)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ)
    @ignore_derivatives if running_stats !== nothing && training
        update_running_stats!(running_stats, x, μ, σ², reduce_dims)
    end
    return y
end

"""
    instancenorm(x::AbstractArray{<:Any, N},
                 running_stats::Union{RunningStats, Nothing} = nothing,
                 scale::Union{AbstractVector, Nothing} = nothing,
                 bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                 training::Bool = within_grad()) where {N}

Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`instancenorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
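
# Examples

A sketch of stateless use; each per-channel, per-instance slice is normalized,
so only summary statistics of the output are checked:

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels

julia> y = NNlib.instancenorm(xs);

julia> isapprox(std(y; dims = (1, 2)), ones(1, 1, 3, 2); atol = 0.2) &&
       std(y; dims = (1, 2)) != std(xs; dims = (1, 2))
true
```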
"""
function instancenorm(x::AbstractArray{<:Any, N},
                      running_stats::Union{RunningStats, Nothing} = nothing,
                      scale::Union{AbstractVector, Nothing} = nothing,
                      bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                      training::Bool = within_grad()) where {N}
    affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :)
    reduce_dims = ((1:(N - 2))...,)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size)
    ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training
        μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to reduce (C, N) -> (C,)
        update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims)
    end
    return y
end

"""
    groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N}

Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation.

Normalizes `x` along the first `N - 2` (spatial) dimensions,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension,
and the channel dimension is divided into `groups` groups along which statistics are computed.
The number of channels must be an integer multiple of the number of groups.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels

julia> y = NNlib.groupnorm(xs, 4);

julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) &&
       std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1])
true

julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) &&
       std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2])
true
```
"""
function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N}
    sz = size(x)
    channels = @ignore_derivatives begin
        ch = sz[max(1, N - 1)]
        newch, remainder = divrem(ch, groups)
        remainder == 0 ? newch :
        throw(ArgumentError("channels $ch should be a multiple of groups $groups"))
    end
    affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :)
    grouped_size = (sz[1:(N - 2)]..., channels, groups, :)
    x′ = reshape(x, grouped_size)
    # Reduce over the spatial and within-group channel dims of x′, i.e. dims (1,...,N-1),
    # leaving one statistic per group and batch
    μ, σ² = norm_stats(x′, ((1:(N - 1))...,))
    return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz)
end