Skip to content

Commit 84bd5a1

Browse files
author
Miha Zgubic
committed
add test config from JuliaDiff/ChainRules.jl#441
1 parent b2d0771 commit 84bd5a1

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ include("iterator.jl")
3636
include("output_control.jl")
3737
include("check_result.jl")
3838

39+
include("rule_config.jl")
3940
include("finite_difference_calls.jl")
4041
include("testers.jl")
4142

src/rule_config.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# For testing this config re-dispatches Xrule_via_ad to Xrule without config argument
2+
struct ADviaRuleConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end
3+
4+
function ChainRulesCore.frule_via_ad(config::ADviaRuleConfig, ȧrgs, f, args...; kws...)
5+
ret = frule(config, ȧrgs, f, args...; kws...)
6+
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
7+
ret === nothing && throw(MethodError(frule, (ȧrgs, f, args...)))
8+
return ret
9+
end
10+
11+
function ChainRulesCore.rrule_via_ad(config::ADviaRuleConfig, f, args...; kws...)
12+
ret = rrule(config, f, args...; kws...)
13+
# we don't support actually doing AD: the rule has to exist. lets give helpfulish error
14+
ret === nothing && throw(MethodError(rrule, (f, args...)))
15+
return ret
16+
end

0 commit comments

Comments
 (0)