|
1 | | -"For testing this config re-dispatches Xrule_via_ad to Xrule without config argument" |
2 | | -struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end |
3 | | - |
4 | | -function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...) |
5 | | - ret = frule(ȧrgs, f, args...; kws...) |
6 | | - # we don't support actually doing AD: the rule has to exist. lets give helpfulish error |
7 | | - ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...))) |
8 | | - return ret |
9 | | -end |
10 | | - |
11 | | -function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...) |
12 | | - ret = rrule(f, args...; kws...) |
13 | | - # we don't support actually doing AD: the rule has to exist. lets give helpfulish error |
14 | | - ret === nothing && throw(MethodError(rrule, (f, args...))) |
15 | | - return ret |
16 | | -end |
17 | | - |
18 | | -# A functor for testing |
19 | | -struct Multiplier{T} |
20 | | - x::T |
21 | | -end |
22 | | -(m::Multiplier)(y) = m.x * y |
23 | | -function ChainRulesCore.rrule(m::Multiplier, y) |
24 | | - Multiplier_pullback(z̄) = Tangent{typeof(m)}(; x=y * z̄), m.x * z̄ |
25 | | - return m(y), Multiplier_pullback |
26 | | -end |
27 | | - |
28 | 1 | @testset "Maps and Reductions" begin |
29 | 2 | @testset "sum" begin |
30 | 3 | sizes = (3, 4, 7) |
|
48 | 21 | end |
49 | 22 | end # sum abs2 |
50 | 23 |
|
51 | | - @testset "sum f" begin |
| 24 | + @testset "sum(f, xs)" begin |
52 | 25 | # This calls back into AD |
53 | | - # TODO: we don't have a easy way to test this via ChainRulesTestUtils |
| 26 | + test_rrule(sum, abs, [-4.0, 2.0, 2.0]) |
| 27 | + test_rrule(sum, Multiplier(2.0), [2.0, 4.0, 8.0]) |
54 | 28 |
|
55 | | - _, pb = rrule(ADviaRuleConfig(), sum, abs, [-4.0, 2.0, 2.0]) |
56 | | - @test pb(1.0) == (NoTangent(), NoTangent(), [-1.0, 1.0, 1.0]) |
| 29 | + test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]) # array of arrays |
57 | 30 |
|
58 | | - _, pb2 = rrule(ADviaRuleConfig(), sum, Multiplier(2.0), [2.0, 4.0, 8.0]) |
59 | | - @test pb2(1.0) == ( |
60 | | - NoTangent(), |
61 | | - Tangent{Multiplier{Float64}}(;x=14.0), |
62 | | - [2.0, 2.0, 2.0] |
63 | | - ) |
| 31 | + test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0]) |
| 32 | + test_rrule(sum, abs, [-4.0 2.0; 2.0 -1.0]') |
64 | 33 | end |
65 | 34 |
|
66 | 35 | @testset "prod" begin |
|
0 commit comments