@@ -40,30 +40,127 @@ function rrule(
 end
 
 function rrule(
-    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
-)
-    fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
-    y = sum(first, fx_and_pullbacks; dims=dims)
+    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f::F, xs::AbstractArray; dims=:
+) where {F}
+    # fx_and_pullbacks = map(x->rrule_via_ad(config, f, x), xs)
+    # y = sum(first, fx_and_pullbacks; dims=dims)
+    # pullbacks = last.(fx_and_pullbacks)
 
-    pullbacks = last.(fx_and_pullbacks)
+    y = sum(f, xs; dims=dims)
 
     project = ProjectTo(xs)
 
-    function sum_pullback(ȳ)
-        call(f, x) = f(x)  # we need to broadcast this to handle dims kwarg
-        f̄_and_x̄s = call.(pullbacks, ȳ)
-        # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
-        f̄ = if fieldcount(typeof(f)) === 0  # Then don't need to worry about derivative wrt f
-            NoTangent()
+    # function sum_pullback(ȳ)
+    #     call(b, x) = b(x)  # we need to broadcast this to handle dims kwarg
+    #     f̄_and_x̄s = call.(pullbacks, ȳ)
+    #     # no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
+    #     f̄ = if fieldcount(typeof(f)) === 0  # Then don't need to worry about derivative wrt f
+    #         NoTangent()
+    #     else
+    #         sum(first, f̄_and_x̄s)
+    #     end
+    #     x̄s = map(unthunk ∘ last, f̄_and_x̄s)  # project does not support receiving InplaceableThunks
+    #     return NoTangent(), f̄, project(x̄s)
+    # end
+
+    function sum_pullback_f(dys_raw)
+        dys = unthunk(dys_raw)
+        if Base.issingletontype(F)
+            dfs = NoTangent()
+            # Best case. We evaluate `f` a second time, then immediately its pullback, so that
+            # we don't need to save an array of these. The point of `sum(f, xs)` is memory-saving.
+            dxs = broadcast(xs, dys) do x, dy
+                _, bk = rrule_via_ad(config, f, x)
+                df, dx = bk(dy)
+                unthunk(dx)
+            end
         else
-            sum(first, f̄_and_x̄s)
+            # We need to accumulate the gradient with respect to `f`. To avoid making a big array
+            # and unzipping, can we accumulate from inside the broadcast needed for `dxs`?
+            _, bk1 = rrule_via_ad(config, f, first(xs))
+            dy1 = dims isa Colon ? dys : first(dys)
+            df_ref = Ref(zero(bk1(dy1)))
+            dxs = broadcast(xs, dys) do x, dy
+                _, bk = rrule_via_ad(config, f, x)
+                df, dx = bk(dy)
+                df_ref[] += df  # not sure this will work!
+                unthunk(dx)
+            end
+            dfs = df_ref[]  # s for sum, not plural!
         end
-        x̄s = map(unthunk ∘ last, f̄_and_x̄s)  # project does not support receiving InplaceableThunks
-        return NoTangent(), f̄, project(x̄s)
+        return (NoTangent(), dfs, project(dxs))
     end
-    return y, sum_pullback
+    return y, sum_pullback_f
 end
 
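The `else` branch above relies on mutating a captured `Ref` from inside `broadcast`, which the commit itself flags as uncertain. A minimal sketch in plain Julia (no AD; the names `g`, `acc`, and `ys` are illustrative only) shows how `Base.issingletontype` selects the branch, and that the `Ref` accumulation does work for ordinary CPU `Array`s, though a GPU broadcast would not preserve the side effect:

# Branch selection: `sqrt` has no fields, a closure over a local variable does.
Base.issingletontype(typeof(sqrt))   # true  -> dfs = NoTangent()
g = let c = 2.0
    x -> c * x                       # captures the local `c` as a field of typeof(g)
end
Base.issingletontype(typeof(g))      # false -> must accumulate a tangent for `f`

# Ref-accumulation inside broadcast, the pattern behind `df_ref[] += df`:
acc = Ref(0.0)
ys = broadcast([1.0, 2.0, 3.0]) do x
    acc[] += x    # side effect: running sum, analogous to `df_ref[] += df`
    x^2           # the value stored in `ys`
end
acc[]             # 6.0
ys                # [1.0, 4.0, 9.0]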
+#=
+
+# Timing this change.
+
+
+julia> @btime sum(sqrt, x) setup=(x=$(rand(30,30)));
+  833.333 ns (0 allocations: 0 bytes)
+
+julia> @btime gradient(x -> sum(sqrt, x), $(rand(30,30)));
+  4.173 μs (16 allocations: 50.02 KiB)  # before
+  1.954 μs (2 allocations: 7.20 KiB)    # after
+
+julia> @btime sum(sqrt, x, dims=1) setup=(x=$(rand(30,30)));
+  772.174 ns (1 allocation: 336 bytes)
+
+julia> @btime gradient(x -> sum(sum(sqrt, x, dims=1)), $(rand(30,30)));
+  10.625 μs (42 allocations: 51.47 KiB)  # before
+  2.704 μs (18 allocations: 8.20 KiB)    # after
+
+# compare broadcasting:
+
+julia> @btime sum(sqrt.(x)) setup=(x=$(rand(30,30)));
+  873.544 ns (1 allocation: 7.19 KiB)
+
+julia> @btime gradient(x -> sum(sqrt.(x)), $(rand(30,30)));
+  2.616 μs (10 allocations: 28.70 KiB)
+
+julia> @btime sum(sqrt.(x), dims=1) setup=(x=$(rand(30,30)));
+  953.667 ns (2 allocations: 7.52 KiB)
+
+julia> @btime gradient(x -> sum(sum(sqrt.(x), dims=1)), $(rand(30,30)));
+  3.542 μs (26 allocations: 36.81 KiB)
+
+
+# Bigger example, slower function?
+
+
+julia> @btime sum(log∘exp, x) setup=(x=$(rand(300,300)));
+  1.348 ms (0 allocations: 0 bytes)
+
+julia> @btime gradient(x -> sum(log∘exp, x), $(rand(300,300)));
+  930.421 ms (8460059 allocations: 206.68 MiB)  # before
+  873.300 ms (7740033 allocations: 187.46 MiB)  # after
+
+julia> @btime sum(log∘exp, x, dims=1) setup=(x=$(rand(300,300)));
+  1.349 ms (1 allocation: 2.50 KiB)
+
+julia> @btime gradient(x -> sum(sum(log∘exp, x, dims=1)), $(rand(300,300)));
+  935.160 ms (8460088 allocations: 206.69 MiB)  # before
+  890.860 ms (7740037 allocations: 187.46 MiB)  # after
+
+# compare broadcasting:
+
+julia> @btime sum((log∘exp).(x)) setup=(x=$(rand(300,300)));
+  1.342 ms (2 allocations: 703.20 KiB)
+
+julia> @btime gradient(x -> sum((log∘exp).(x)), $(rand(300,300)));
+  1.449 ms (27 allocations: 2.75 MiB)
+
+julia> @btime sum((log∘exp).(x), dims=1) setup=(x=$(rand(300,300)));
+  1.380 ms (3 allocations: 705.70 KiB)
+
+julia> @btime gradient(x -> sum(sum((log∘exp).(x), dims=1)), $(rand(300,300)));
+  1.490 ms (16 allocations: 3.44 MiB)
+
+=#
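Beyond the timings above, a rough sanity check of the rewritten pullback is sketched below. It assumes the rule is exercised through an AD that provides `rrule_via_ad` for a `RuleConfig{>:HasReverseMode}` (Zygote, for example), and it only hits the singleton-`f` branch, since `sqrt` has no fields:

using Zygote  # assumed here to supply the HasReverseMode RuleConfig and rrule_via_ad

xs = rand(30, 30)
dxs = gradient(x -> sum(sqrt, x), xs)[1]
dxs ≈ map(x -> 1 / (2 * sqrt(x)), xs)              # expect true: d/dx sqrt(x) = 1/(2√x)

# the dims keyword routes through the same pullback:
dxs1 = gradient(x -> sum(sum(sqrt, x, dims=1)), xs)[1]
dxs1 ≈ dxs                                         # expect true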
+
+
 function frule(
     (_, _, Δx),
     ::typeof(sum),