@@ -232,62 +232,51 @@ end
232232rrule (:: typeof (cumsum), x:: AbstractVector ) = rrule (cumsum, x; dims= 1 )
233233
234234# ####
235- # #### `maximum`, `minimum`
235+ # #### `maximum(f, xs) `, `minimum(f, xs) `
236236# ####
237237
238+ # Rules for `maximum(x)` live with `findmax(x)` in array.jl
239+
238240for mimum in (:minimum , :maximum )
239- pullback1 = Symbol (mimum, :_pullback_f )
240- pullback2 = Symbol (mimum, :_pullback_composed )
241241 findm = Symbol (:find , string (mimum)[1 : 3 ])
242242
243243 @eval function rrule (
244- config:: RuleConfig{>:HasReverseMode} , :: typeof ($ mimum), f:: F , xs:: AbstractArray{<:Number} ; dims= :
244+ config:: RuleConfig{>:HasReverseMode} ,
245+ :: typeof ($ mimum),
246+ f:: F ,
247+ xs:: AbstractArray{<:Number} ;
248+ dims= :,
245249 ) where {F}
246250 project = ProjectTo (xs)
247-
248- # The easy case is when we can use `findmax` to get index, and write into it:
249- if dims isa Colon && VERSION >= v " 1.7-"
250- y, ind = $ findm (f, xs)
251- function $pullback1 (dy)
252- # Notice this evaluates `f` one more time, but this shouldn't matter
253- # unless `f` is sateful, in which case both this and `maximum(f.(xs))`
254- # give undefined results.
255- _, one_back = rrule_via_ad (config, f, xs[ind])
256- df, one_dx_raw = one_back (unthunk (dy))
257- one_dx = unthunk (one_dx_raw)
258- x_thunk = @thunk project (_writezero (xs, one_dx, ind, dims))
259- x_ithunk = InplaceableThunk (x_thunk) do dxs
260- view (dxs, ind) .+ = one_dx
261- dxs
262- end
263- return (NoTangent (), df, x_ithunk)
251+ if dims isa Colon && VERSION >= v " 1.7"
252+ # The fast case is when we can use `findmax` to get index, and write into it:
253+ y1, ind = $ findm (f, xs) # (Julia 1.6 doesn't have this method.)
254+ function minormax_f_back1 (dy)
255+ # Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
256+ # stateful, in which case both this and `maximum(f.(xs))` give uncertain results.
257+ y_ad, one_back = rrule_via_ad (config, f, xs[ind])
258+ isapprox (y_ad, y1) || throw (ArgumentError (" expected `f` to give same result with AD, got $y_ad != $y1 " ))
259+ df, one_dx = one_back (unthunk (dy))
260+ dxs = _zerolike_writeat (xs, unthunk (one_dx), dims, ind) # TODO make _zerolike_writeat handle thunks
261+ return (NoTangent (), df, project (dxs))
264262 end
265- return y, $ pullback1
263+ return y1, minormax_f_back1
266264
267- # Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
268265 else
269- mid, cast_back = rrule_via_ad (config, broadcast, f, xs; dims= dims)
270- y, max_back = rrule ($ mimum, fxs; dims= dims)
271- function $pullback2 (dys)
272- _, dmid = max_back (dys)
273- _, df, dxs = cast_back (dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
274- return (NoTangent (), df, project (dxs)) # then this project() will give an error.
266+ # Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
267+ fxs, cast_back = rrule_via_ad (config, broadcast, f, xs)
268+ y2, mm_back = rrule ($ mimum, fxs; dims)
269+ function minormax_f_back2 (dy)
270+ _, dmid = mm_back (dy)
271+ _, df, dxs = cast_back (dmid)
272+ return (NoTangent (), df, project (dxs))
275273 end
276- return y, $ pullback2
277- end
274+ return y2, minormax_f_back2
278275
276+ end
279277 end # @eval function rrule(...)
280278end
281279
282- # from another PR:
283- function _writezero (x, dy, ind, dims)
284- # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
285- # allow `eltype(dy)`, nor does it work for many structured matrices.
286- dx = fill! (similar (x, eltype (dy), axes (x)), false )
287- view (dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
288- dx
289- end
290-
291280# ####
292281# #### `prod`
293282# ####
0 commit comments