Skip to content

Commit 187ed27

Browse files
committed
update, tidy
1 parent d16a90c commit 187ed27

File tree

3 files changed

+32
-51
lines changed

3 files changed

+32
-51
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -188,48 +188,49 @@ end
188188
rrule(::typeof(cumsum), x::AbstractVector) = rrule(cumsum, x; dims=1)
189189

190190
#####
191-
##### `maximum`, `minimum`
191+
##### `maximum(f, xs)`, `minimum(f, xs)`
192192
#####
193193

194+
# Rules for `maximum(x)` live with `findmax(x)` in array.jl
195+
194196
for mimum in (:minimum, :maximum)
195-
pullback1 = Symbol(mimum, :_pullback_f)
196-
pullback2 = Symbol(mimum, :_pullback_composed)
197197
findm = Symbol(:find, string(mimum)[1:3])
198198

199199
@eval function rrule(
200-
config::RuleConfig{>:HasReverseMode}, ::typeof($mimum), f::F, xs::AbstractArray{<:Number}; dims=:
200+
config::RuleConfig{>:HasReverseMode},
201+
::typeof($mimum),
202+
f::F,
203+
xs::AbstractArray{<:Number};
204+
dims = :,
201205
) where {F}
202206
project = ProjectTo(xs)
203-
204-
# The easy case is when we can use `findmax` to get index, and write into it:
205-
if dims isa Colon && VERSION >= v"1.7-"
207+
if dims isa Colon && VERSION >= v"1.7"
208+
# The easy case is when we can use `findmax` to get index, and write into it:
206209
y, ind = $findm(f, xs)
207-
function $pullback1(dy)
208-
# Notice this evaluates `f` one more time, but this shouldn't matter
209-
# unless `f` is sateful, in which case both this and `maximum(f.(xs))`
210-
# give undefined results.
210+
function minormax_f_back1(dy)
211+
# Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
212+
# sateful, in which case both this and `maximum(f.(xs))` give uncertain results.
211213
_, one_back = rrule_via_ad(config, f, xs[ind])
212-
df, one_dx_raw = one_back(unthunk(dy))
213-
one_dx = unthunk(one_dx_raw)
214-
x_thunk = @thunk project(_zerolike_writeat(xs, one_dx, dims, ind))
214+
df, one_dx = one_back(unthunk(dy))
215+
x_thunk = @thunk project(_zerolike_writeat(xs, unthunk(one_dx), dims, ind))
215216
x_ithunk = InplaceableThunk(x_thunk) do dxs
216-
view(dxs, ind) .+= one_dx
217+
view(dxs, ind) .+= unthunk(one_dx) # TODO make _zerolike_writeat handle thunks
217218
dxs
218219
end
219220
return (NoTangent(), df, x_ithunk)
220221
end
221-
return y, $pullback1
222+
return y, minormax_f_back1
222223

223-
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
224224
else
225-
mid, cast_back = rrule_via_ad(config, broadcast, f, xs; dims=dims)
226-
y, max_back = rrule($mimum, fxs; dims=dims)
227-
function $pullback2(dys)
228-
_, dmid = max_back(dys)
229-
_, df, dxs = cast_back(dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
230-
return (NoTangent(), df, project(dxs)) # then this project() will give an error.
225+
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
226+
fxs, cast_back = rrule_via_ad(config, broadcast, f, xs)
227+
y, mm_back = rrule($mimum, fxs; dims)
228+
function minormax_f_back2(dy)
229+
_, dmid = mm_back(dy)
230+
_, df, dxs = cast_back(dmid)
231+
return (NoTangent(), df, project(dxs))
231232
end
232-
return y, $pullback2
233+
return y, minormax_f_back2
233234
end
234235

235236
end # @eval function rrule(...)

test/rulesets/Base/mapreduce.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,21 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
120120
@test rrule(SumRuleConfig(), Base.sum, xs, weights) isa Nothing
121121
end
122122

123-
@testset "maximum(f, xs)" begin
124-
# This calls back into AD
123+
VERSION >= v"1.7" && @testset "maximum(f, xs)" begin
125124
test_rrule(maximum, abs, [-4.0, 2.0, 2.0])
126125
test_rrule(minimum, sqrt, Float64[1 2; 3 4])
127126
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0]) # Multiplier defined in test_helpers.jl
128127

129-
# dims keyword -- these need to call `rrule_via_ad(broadcast, ...`
130-
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(;dims=1), check_inferred=false)
131-
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(;dims=2), check_inferred=false)
132-
133128
# repeated -- can't use FiniteDifferences
134-
y1, bk1 = rrule(TestConfigReverse(), maximum, abs, [-4.0, 2.0, 4.0, 2.0]) # TestConfigReverse defined in test_helpers.jl
129+
y1, bk1 = rrule(CFG, maximum, abs, [-4.0, 2.0, 4.0, 2.0])
135130
@test y1 === 4.0
136131
@test unthunk(bk1(10.0)[3]) == [-10, 0, 0, 0]
137132

138-
# y2, bk2 = rrule(TestConfigReverse(), minimum, abs, [1 2 3; -5 -4 -4], dims=2)
133+
# dims keyword -- these need to call `rrule_via_ad(broadcast, ...`, which needs AD
134+
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
135+
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)
136+
137+
@test_skip y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2)
139138
# @test y2 == hcat([1, 4])
140139
# @test unthunk(bk2(hcat([10, 20]))[3]) == [10 0 0; 0 -20 0]
141140
end

test/test_helpers.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,6 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
7575
return make_two_vec(x), make_two_vec_pullback
7676
end
7777

78-
# Trivial rule configurations, allowing `rrule_via_ad` with simple functions:
79-
struct TestConfigReverse <: RuleConfig{HasReverseMode} end
80-
function ChainRulesCore.rrule_via_ad(::TestConfigReverse, args...; kw...)
81-
if hasmethod(rrule, typeof(args), keys(kw))
82-
rrule(args...; kw...)
83-
else
84-
error("TestConfigReverse can only handle `rrule_via_ad(f, args...)` when there is an rrule method")
85-
end
86-
end
87-
88-
struct TestConfigForwards <: RuleConfig{HasForwardsMode} end
89-
function ChainRulesCore.frule_via_ad(::TestConfigReverse, args...; kw...)
90-
if hasmethod(frule, typeof(args), keys(kw))
91-
frule(args...; kw...)
92-
else
93-
error("TestConfigForwards can only handle `frule_via_ad(f, args...)` when there is an frule method")
94-
end
95-
end
96-
9778
@testset "test_helpers.jl" begin
9879

9980
@testset "Multiplier" begin

0 commit comments

Comments
 (0)