Skip to content

Commit 0d1d140

Browse files
committed
Fix dispatch for config-free sum
1 parent db40349 commit 0d1d140

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,17 @@ function rrule(
7777
return y, sum_abs2_pullback
7878
end
7979

80+
# Fix dispatch for this pidgeon-hole optimization,
81+
# Rules with RuleConfig dispatch with priority over without (regardless of other args).
82+
# and if we don't specify what do do for one that HasReverseMode then it is ambigious
83+
for Config in (RuleConfig, RuleConfig{>:HasReverseMode})
84+
@eval function rrule(
85+
::$Config, ::typeof(sum), ::typeof(abs2), x::AbstractArray{T}; dims=:,
86+
) where {T<:Union{Real,Complex}}
87+
return rrule(sum, abs2, x; dims=dims)
88+
end
89+
end
90+
8091
#####
8192
##### `prod`
8293
#####

0 commit comments

Comments
 (0)