Skip to content

Commit 30532f6

Browse files
committed
update
fixup update, tidy Apply 3 suggestions Co-authored-by: Miha Zgubic <[email protected]> add an error remove error, as closing over `y` breaks inference simplify, update solve Core.Box tests approx
1 parent 9ab580f commit 30532f6

File tree

4 files changed

+48
-77
lines changed

4 files changed

+48
-77
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
1616
ChainRulesCore = "1.12"
17-
ChainRulesTestUtils = "1.5"
18-
Compat = "3.42.0"
19-
FiniteDifferences = "0.12.20"
17+
ChainRulesTestUtils = "1.6"
18+
Compat = "3.42"
19+
FiniteDifferences = "0.12.24"
2020
IrrationalConstants = "0.1.1"
2121
JuliaInterpreter = "0.8,0.9"
2222
RealDot = "0.1"
@@ -33,3 +33,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3333

3434
[targets]
3535
test = ["ChainRulesTestUtils", "FiniteDifferences", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
36+

src/rulesets/Base/mapreduce.jl

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -232,62 +232,51 @@ end
232232
rrule(::typeof(cumsum), x::AbstractVector) = rrule(cumsum, x; dims=1)
233233

234234
#####
235-
##### `maximum`, `minimum`
235+
##### `maximum(f, xs)`, `minimum(f, xs)`
236236
#####
237237

238+
# Rules for `maximum(x)` live with `findmax(x)` in array.jl
239+
238240
for mimum in (:minimum, :maximum)
239-
pullback1 = Symbol(mimum, :_pullback_f)
240-
pullback2 = Symbol(mimum, :_pullback_composed)
241241
findm = Symbol(:find, string(mimum)[1:3])
242242

243243
@eval function rrule(
244-
config::RuleConfig{>:HasReverseMode}, ::typeof($mimum), f::F, xs::AbstractArray{<:Number}; dims=:
244+
config::RuleConfig{>:HasReverseMode},
245+
::typeof($mimum),
246+
f::F,
247+
xs::AbstractArray{<:Number};
248+
dims=:,
245249
) where {F}
246250
project = ProjectTo(xs)
247-
248-
# The easy case is when we can use `findmax` to get index, and write into it:
249-
if dims isa Colon && VERSION >= v"1.7-"
250-
y, ind = $findm(f, xs)
251-
function $pullback1(dy)
252-
# Notice this evaluates `f` one more time, but this shouldn't matter
253-
# unless `f` is sateful, in which case both this and `maximum(f.(xs))`
254-
# give undefined results.
255-
_, one_back = rrule_via_ad(config, f, xs[ind])
256-
df, one_dx_raw = one_back(unthunk(dy))
257-
one_dx = unthunk(one_dx_raw)
258-
x_thunk = @thunk project(_writezero(xs, one_dx, ind, dims))
259-
x_ithunk = InplaceableThunk(x_thunk) do dxs
260-
view(dxs, ind) .+= one_dx
261-
dxs
262-
end
263-
return (NoTangent(), df, x_ithunk)
251+
if dims isa Colon && VERSION >= v"1.7"
252+
# The fast case is when we can use `findmax` to get index, and write into it:
253+
y1, ind = $findm(f, xs) # (Julia 1.6 doesn't have this method.)
254+
function minormax_f_back1(dy)
255+
# Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
256+
# stateful, in which case both this and `maximum(f.(xs))` give uncertain results.
257+
y_ad, one_back = rrule_via_ad(config, f, xs[ind])
258+
isapprox(y_ad, y1) || throw(ArgumentError("expected `f` to give same result with AD, got $y_ad != $y1"))
259+
df, one_dx = one_back(unthunk(dy))
260+
dxs = _zerolike_writeat(xs, unthunk(one_dx), dims, ind) # TODO make _zerolike_writeat handle thunks
261+
return (NoTangent(), df, project(dxs))
264262
end
265-
return y, $pullback1
263+
return y1, minormax_f_back1
266264

267-
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
268265
else
269-
mid, cast_back = rrule_via_ad(config, broadcast, f, xs; dims=dims)
270-
y, max_back = rrule($mimum, fxs; dims=dims)
271-
function $pullback2(dys)
272-
_, dmid = max_back(dys)
273-
_, df, dxs = cast_back(dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
274-
return (NoTangent(), df, project(dxs)) # then this project() will give an error.
266+
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
267+
fxs, cast_back = rrule_via_ad(config, broadcast, f, xs)
268+
y2, mm_back = rrule($mimum, fxs; dims)
269+
function minormax_f_back2(dy)
270+
_, dmid = mm_back(dy)
271+
_, df, dxs = cast_back(dmid)
272+
return (NoTangent(), df, project(dxs))
275273
end
276-
return y, $pullback2
277-
end
274+
return y2, minormax_f_back2
278275

276+
end
279277
end # @eval function rrule(...)
280278
end
281279

282-
# from another PR:
283-
function _writezero(x, dy, ind, dims)
284-
# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
285-
# allow `eltype(dy)`, nor does it work for many structured matrices.
286-
dx = fill!(similar(x, eltype(dy), axes(x)), false)
287-
view(dx, ind) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
288-
dx
289-
end
290-
291280
#####
292281
##### `prod`
293282
#####

test/rulesets/Base/mapreduce.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
33
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
44

5-
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6-
75
@testset "Reductions" begin
86
@testset "sum(::Tuple)" begin
97
test_frule(sum, Tuple(rand(5)))
@@ -137,23 +135,22 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
137135
end
138136

139137
@testset "maximum(f, xs)" begin
140-
# This calls back into AD
141-
test_rrule(maximum, abs, [-4.0, 2.0, 2.0], check_inferred=false)
142-
test_rrule(minimum, sqrt, Float64[1 2; 3 4], check_inferred=false)
143-
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0], check_inferred=false) # Multiplier defined in test_helpers.jl
144-
145-
# dims keyword
146-
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(;dims=1), check_inferred=false)
147-
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(;dims=2), check_inferred=false)
138+
test_rrule(maximum, abs, [-4.0, 2.0, 2.0])
139+
test_rrule(minimum, sqrt, Float64[1 2; 3 4])
140+
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0]) # Multiplier defined in test_helpers.jl
148141

149142
# repeated -- can't use FiniteDifferences
150-
y1, bk1 = rrule(TestConfigReverse(), maximum, abs, [-4.0, 2.0, 4.0, 2.0]) # TestConfigReverse defined in test_helpers.jl
143+
y1, bk1 = rrule(CFG, maximum, abs, [-4.0, 2.0, 4.0, 2.0]) # CFG defined in test_helpers.jl
151144
@test y1 === 4.0
152-
@test unthunk(bk1(10.0)[3]) == [-10, 0, 0, 0]
145+
@test unthunk(bk1(10.0)[3]) [-10, 0, 0, 0]
146+
147+
# dims keyword -- these call `rrule_via_ad(broadcast, ...`
148+
test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
149+
test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)
153150

154-
# y2, bk2 = rrule(TestConfigReverse(), minimum, abs, [1 2 3; -5 -4 -4], dims=2)
155-
# @test y2 == hcat([1, 4])
156-
# @test unthunk(bk2(hcat([10, 20]))[3]) == [10 0 0; 0 -20 0]
151+
y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2)
152+
@test y2 == hcat([1, 4])
153+
@test_broken unthunk(bk2(hcat([10, 20]))[3]) [10 0 0; 0 -20 0] # This used to work? Fine in Zygote
157154
end
158155

159156
@testset "prod" begin

test/test_helpers.jl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
const CFG = ChainRulesTestUtils.TestConfig() # CRTU v1.6
3+
14
"""
25
Multiplier(x)
36
@@ -97,25 +100,6 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
97100
return make_two_vec(x), make_two_vec_pullback
98101
end
99102

100-
# Trivial rule configurations, allowing `rrule_via_ad` with simple functions:
101-
struct TestConfigReverse <: RuleConfig{HasReverseMode} end
102-
function ChainRulesCore.rrule_via_ad(::TestConfigReverse, args...; kw...)
103-
if hasmethod(rrule, typeof(args), keys(kw))
104-
rrule(args...; kw...)
105-
else
106-
error("TestConfigReverse can only handle `rrule_via_ad(f, args...)` when there is an rrule method")
107-
end
108-
end
109-
110-
struct TestConfigForwards <: RuleConfig{HasForwardsMode} end
111-
function ChainRulesCore.frule_via_ad(::TestConfigReverse, args...; kw...)
112-
if hasmethod(frule, typeof(args), keys(kw))
113-
frule(args...; kw...)
114-
else
115-
error("TestConfigForwards can only handle `frule_via_ad(f, args...)` when there is an frule method")
116-
end
117-
end
118-
119103
@testset "test_helpers.jl" begin
120104

121105
@testset "Multiplier" begin

0 commit comments

Comments
 (0)