Skip to content

Commit 75d7762

Browse files
committed
use ChainRulesTestUtils to test
1 parent 2e9656e commit 75d7762

File tree

1 file changed

+6
-37
lines changed

1 file changed

+6
-37
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,3 @@
1-
"For testing this config re-dispatches Xrule_via_ad to Xrule without config argument"
2-
struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
3-
4-
function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...)
5-
ret = frule(ȧrgs, f, args...; kws...)
6-
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
7-
ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...)))
8-
return ret
9-
end
10-
11-
function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...)
12-
ret = rrule(f, args...; kws...)
13-
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
14-
ret === nothing && throw(MethodError(rrule, (f, args...)))
15-
return ret
16-
end
17-
18-
# A functor for testing
19-
struct Multiplier{T}
20-
x::T
21-
end
22-
(m::Multiplier)(y) = m.x * y
23-
function ChainRulesCore.rrule(m::Multiplier, y)
24-
Multiplier_pullback(z̄) = Tangent{typeof(m)}(; x=y * z̄), m.x *
25-
return m(y), Multiplier_pullback
26-
end
27-
281
@testset "Maps and Reductions" begin
292
@testset "sum" begin
303
sizes = (3, 4, 7)
@@ -48,19 +21,15 @@ end
4821
end
4922
end # sum abs2
5023

51-
@testset "sum f" begin
24+
@testset "sum(f, xs)" begin
5225
# This calls back into AD
53-
# TODO: we don't have a easy way to test this via ChainRulesTestUtils
26+
test_rrule(sum, abs, [-4.0, 2.0, 2.0])
27+
test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0])
5428

55-
_, pb = rrule(ADviaRuleConfig(), sum, abs, [-4.0, 2.0, 2.0])
56-
@test pb(1.0) == (NoTangent(), NoTangent(), [-1.0, 1.0, 1.0])
29+
test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]) # array of arrays
5730

58-
_, pb2 = rrule(ADviaRuleConfig(), sum, Multiplier(2.0), [2.0, 4.0, 8.0])
59-
@test pb2(1.0) == (
60-
NoTangent(),
61-
Tangent{Multiplier{Float64}}(;x=14.0),
62-
[2.0, 2.0, 2.0]
63-
)
31+
test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0])
32+
test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0]')
6433
end
6534

6635
@testset "prod" begin

0 commit comments

Comments
 (0)