@@ -188,48 +188,49 @@ end
188188rrule (:: typeof (cumsum), x:: AbstractVector ) = rrule (cumsum, x; dims= 1 )
189189
190190# ####
191- # #### `maximum`, `minimum`
191+ # #### `maximum(f, xs) `, `minimum(f, xs) `
192192# ####
193193
194+ # Rules for `maximum(x)` live with `findmax(x)` in array.jl
195+
194196for mimum in (:minimum , :maximum )
195- pullback1 = Symbol (mimum, :_pullback_f )
196- pullback2 = Symbol (mimum, :_pullback_composed )
197197 findm = Symbol (:find , string (mimum)[1 : 3 ])
198198
199199 @eval function rrule (
200- config:: RuleConfig{>:HasReverseMode} , :: typeof ($ mimum), f:: F , xs:: AbstractArray{<:Number} ; dims= :
200+ config:: RuleConfig{>:HasReverseMode} ,
201+ :: typeof ($ mimum),
202+ f:: F ,
203+ xs:: AbstractArray{<:Number} ;
204+ dims = :,
201205 ) where {F}
202206 project = ProjectTo (xs)
203-
204- # The easy case is when we can use `findmax` to get index, and write into it:
205- if dims isa Colon && VERSION >= v " 1.7-"
207+ if dims isa Colon && VERSION >= v " 1.7"
208+ # The easy case is when we can use `findmax` to get index, and write into it:
206209 y, ind = $ findm (f, xs)
207- function $pullback1 (dy)
208- # Notice this evaluates `f` one more time, but this shouldn't matter
209- # unless `f` is sateful, in which case both this and `maximum(f.(xs))`
210- # give undefined results.
210+ function minormax_f_back1 (dy)
211+ # Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
212+ # sateful, in which case both this and `maximum(f.(xs))` give uncertain results.
211213 _, one_back = rrule_via_ad (config, f, xs[ind])
212- df, one_dx_raw = one_back (unthunk (dy))
213- one_dx = unthunk (one_dx_raw)
214- x_thunk = @thunk project (_zerolike_writeat (xs, one_dx, dims, ind))
214+ df, one_dx = one_back (unthunk (dy))
215+ x_thunk = @thunk project (_zerolike_writeat (xs, unthunk (one_dx), dims, ind))
215216 x_ithunk = InplaceableThunk (x_thunk) do dxs
216- view (dxs, ind) .+ = one_dx
217+ view (dxs, ind) .+ = unthunk ( one_dx) # TODO make _zerolike_writeat handle thunks
217218 dxs
218219 end
219220 return (NoTangent (), df, x_ithunk)
220221 end
221- return y, $ pullback1
222+ return y, minormax_f_back1
222223
223- # Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
224224 else
225- mid, cast_back = rrule_via_ad (config, broadcast, f, xs; dims= dims)
226- y, max_back = rrule ($ mimum, fxs; dims= dims)
227- function $pullback2 (dys)
228- _, dmid = max_back (dys)
229- _, df, dxs = cast_back (dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
230- return (NoTangent (), df, project (dxs)) # then this project() will give an error.
225+ # Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
226+ fxs, cast_back = rrule_via_ad (config, broadcast, f, xs)
227+ y, mm_back = rrule ($ mimum, fxs; dims)
228+ function minormax_f_back2 (dy)
229+ _, dmid = mm_back (dy)
230+ _, df, dxs = cast_back (dmid)
231+ return (NoTangent (), df, project (dxs))
231232 end
232- return y, $ pullback2
233+ return y, minormax_f_back2
233234 end
234235
235236 end # @eval function rrule(...)
0 commit comments