
Commit c8034da

save less stuff in sum(f, xs) rule
1 parent c66c4a9 commit c8034da

File tree

1 file changed (+112, -15 lines)


src/rulesets/Base/mapreduce.jl

Lines changed: 112 additions & 15 deletions
@@ -40,30 +40,127 @@ function rrule(
 end

 function rrule(
-    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
-)
-    fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
-    y = sum(first, fx_and_pullbacks; dims=dims)
+    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f::F, xs::AbstractArray; dims=:
+) where {F}
+    # fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
+    # y = sum(first, fx_and_pullbacks; dims=dims)
+    # pullbacks = last.(fx_and_pullbacks)

-    pullbacks = last.(fx_and_pullbacks)
+    y = sum(f, xs; dims=dims)

     project = ProjectTo(xs)

-    function sum_pullback(ȳ)
-        call(f, x) = f(x) # we need to broadcast this to handle dims kwarg
-        f̄_and_x̄s = call.(pullbacks, ȳ)
-        # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
-        f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
-            NoTangent()
+    # function sum_pullback(ȳ)
+    #     call(b, x) = b(x) # we need to broadcast this to handle dims kwarg
+    #     f̄_and_x̄s = call.(pullbacks, ȳ)
+    #     # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
+    #     f̄ = if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
+    #         NoTangent()
+    #     else
+    #         sum(first, f̄_and_x̄s)
+    #     end
+    #     x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
+    #     return NoTangent(), f̄, project(x̄s)
+    # end
+
+    function sum_pullback_f(dys_raw)
+        dys = unthunk(dys_raw)
+        if Base.issingletontype(F)
+            dfs = NoTangent()
+            # Best case. We evaluate `f` a second time, then immediately its pullback, so that
+            # we don't need to save an array of these. The point of `sum(f, xs)` is memory-saving.
+            dxs = broadcast(xs, dys) do x, dy
+                _, bk = rrule_via_ad(config,f,x)
+                df, dx = bk(dy)
+                unthunk(dx)
+            end
         else
-            sum(first, f̄_and_x̄s)
+            # We need to accumulate the gradient with respect to `f`. To avoid making a big array
+            # and unzipping, can we accumulate from inside the broadcast needed for `dxs`?
+            _, bk1 = rrule_via_ad(config,f,first(xs))
+            dy1 = dims isa Colon ? dys : first(dys)
+            df_ref = Ref(zero(bk1(dy1)))
+            dxs = broadcast(xs, dys) do x, dy
+                _, bk = rrule_via_ad(config,f,x)
+                df, dx = bk(dy)
+                df_ref[] += df # not sure this will work!
+                unthunk(dx)
+            end
+            dfs = df_ref[] # s for sum, not plural!
         end
-        x̄s = map(unthunk ∘ last, f̄_and_x̄s) # project does not support receiving InplaceableThunks
-        return NoTangent(), f̄, project(x̄s)
+        return (NoTangent(), dfs, project(dxs))
     end
-    return y, sum_pullback
+    return y, sum_pullback_f
 end

+#=
+
+# Timing this change.
+
+julia> @btime sum(sqrt, x) setup=(x=$(rand(30,30)));
+  833.333 ns (0 allocations: 0 bytes)
+
+julia> @btime gradient(x -> sum(sqrt, x), $(rand(30,30)));
+  4.173 μs (16 allocations: 50.02 KiB)   # before
+  1.954 μs (2 allocations: 7.20 KiB)     # after
+
+julia> @btime sum(sqrt, x, dims=1) setup=(x=$(rand(30,30)));
+  772.174 ns (1 allocation: 336 bytes)
+
+julia> @btime gradient(x -> sum(sum(sqrt, x, dims=1)), $(rand(30,30)));
+  10.625 μs (42 allocations: 51.47 KiB)  # before
+  2.704 μs (18 allocations: 8.20 KiB)    # after
+
+# compare broadcasting:
+
+julia> @btime sum(sqrt.(x)) setup=(x=$(rand(30,30)));
+  873.544 ns (1 allocation: 7.19 KiB)
+
+julia> @btime gradient(x -> sum(sqrt.(x)), $(rand(30,30)));
+  2.616 μs (10 allocations: 28.70 KiB)
+
+julia> @btime sum(sqrt.(x), dims=1) setup=(x=$(rand(30,30)));
+  953.667 ns (2 allocations: 7.52 KiB)
+
+julia> @btime gradient(x -> sum(sum(sqrt.(x), dims=1)), $(rand(30,30)));
+  3.542 μs (26 allocations: 36.81 KiB)
+
+# Bigger example, slower function?
+
+julia> @btime sum(log∘exp, x) setup=(x=$(rand(300,300)));
+  1.348 ms (0 allocations: 0 bytes)
+
+julia> @btime gradient(x -> sum(log∘exp, x), $(rand(300,300)));
+  930.421 ms (8460059 allocations: 206.68 MiB)  # before
+  873.300 ms (7740033 allocations: 187.46 MiB)  # after
+
+julia> @btime sum(log∘exp, x, dims=1) setup=(x=$(rand(300,300)));
+  1.349 ms (1 allocation: 2.50 KiB)
+
+julia> @btime gradient(x -> sum(sum(log∘exp, x, dims=1)), $(rand(300,300)));
+  935.160 ms (8460088 allocations: 206.69 MiB)  # before
+  890.860 ms (7740037 allocations: 187.46 MiB)  # after
+
+# compare broadcasting:
+
+julia> @btime sum((log∘exp).(x)) setup=(x=$(rand(300,300)));
+  1.342 ms (2 allocations: 703.20 KiB)
+
+julia> @btime gradient(x -> sum((log∘exp).(x)), $(rand(300,300)));
+  1.449 ms (27 allocations: 2.75 MiB)
+
+julia> @btime sum((log∘exp).(x), dims=1) setup=(x=$(rand(300,300)));
+  1.380 ms (3 allocations: 705.70 KiB)
+
+julia> @btime gradient(x -> sum(sum((log∘exp).(x), dims=1)), $(rand(300,300)));
+  1.490 ms (16 allocations: 3.44 MiB)
+
+=#
+
 function frule(
     (_, _, Δx),
     ::typeof(sum),
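
A note on what "save less stuff" buys here, with a minimal sketch (not part of the commit) of the trade-off: the old rule ran `rrule_via_ad` up front and kept one pullback closure per element of `xs` alive until the backward pass, while the new `sum_pullback_f` stores only `xs` and re-runs `f` together with its pullback element by element when the cotangent arrives. The sketch uses `Zygote.pullback` as a stand-in for `rrule_via_ad`, and the helper names `sum_storing_pullbacks` / `sum_recomputing` are made up for illustration.

using Zygote

# Old strategy: one AD forward pass per element up front, keeping every pullback
# closure alive until the backward pass (extra storage proportional to length(xs)).
function sum_storing_pullbacks(f, xs)
    fx_and_bk = map(x -> Zygote.pullback(f, x), xs)
    y = sum(first, fx_and_bk)
    backward(dy) = map(fb -> fb[2](dy)[1], fx_and_bk)
    return y, backward
end

# New strategy: store only xs, and re-run f plus its pullback inside the
# backward pass, so nothing per-element survives the forward pass.
function sum_recomputing(f, xs)
    y = sum(f, xs)
    backward(dy) = map(x -> Zygote.pullback(f, x)[2](dy)[1], xs)
    return y, backward
end

xs = rand(4)
y1, bk1 = sum_storing_pullbacks(sqrt, xs)
y2, bk2 = sum_recomputing(sqrt, xs)
@assert y1 ≈ y2 && bk1(1.0) ≈ bk2(1.0)  # same results, different memory profile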
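
The fast path is now gated on `Base.issingletontype(F)` instead of the old `fieldcount(typeof(f)) === 0`: plain named functions like `sqrt` are singleton types carrying no data, so there is nothing to differentiate with respect to `f`, whereas a closure that captures a variable has fields and falls into the accumulation branch. A REPL sketch of the distinction (not part of the commit):

julia> Base.issingletontype(typeof(sqrt))   # plain function: no fields, takes the NoTangent() branch
true

julia> g = let scale = 2.0                  # closure capturing a local variable
           x -> scale * x
       end;

julia> Base.issingletontype(typeof(g))      # carries data, so the rule must accumulate df
false

julia> fieldcount(typeof(g))                # the quantity the old check inspected
1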
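
The timings in the `#= ... =#` block use `@btime` and `gradient`, which suggests BenchmarkTools plus a reverse-mode AD such as Zygote that reaches this rule through `rrule_via_ad`. Under that assumption, a minimal sketch for reproducing one line and sanity-checking the gradient value:

using BenchmarkTools, Zygote

x = rand(30, 30)

dx = gradient(z -> sum(sqrt, z), x)[1]
@assert dx ≈ map(z -> 1 / (2 * sqrt(z)), x)  # d/dz sqrt(z) = 1/(2*sqrt(z))

@btime gradient(z -> sum(sqrt, z), $x);      # compare allocations before and after this commit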
