Skip to content

Commit 98e87c3

Browse files
committed
handle indexing of GPU arrays
1 parent a68baa5 commit 98e87c3

File tree

6 files changed

+67
-4
lines changed

6 files changed

+67
-4
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
version = "1.41.0"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -15,6 +16,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617

1718
[compat]
19+
Adapt = "3.4.0"
1820
ChainRulesCore = "1.15.3"
1921
ChainRulesTestUtils = "1.5"
2022
Compat = "3.42.0, 4"
@@ -28,7 +30,6 @@ StaticArrays = "1.2"
2830
julia = "1.6"
2931

3032
[extras]
31-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3233
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3334
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3435
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -38,4 +39,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3839
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3940

4041
[targets]
41-
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
42+
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]

src/ChainRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module ChainRules
22

3+
using Adapt: adapt
34
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
45
using ChainRulesCore
56
using Compat
67
using Distributed
8+
using GPUArraysCore: AbstractGPUArray
79
using IrrationalConstants: logtwo, logten
810
using LinearAlgebra
911
using LinearAlgebra.BLAS

src/rulesets/Base/base.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
241241
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
242242
return y, map_pullback
243243
end
244+
245+
#####
246+
##### `task_local_storage`
247+
#####
248+
249+
# Called by `@allowscalar` from GPUArrays
250+
251+
ChainRules.@non_differentiable task_local_storage(key::Any)
252+
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)
253+
254+
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
255+
y, back = task_local_storage(key, value) do
256+
rrule_via_ad(config, body)
257+
end
258+
function task_local_storage_pullback(dy)
259+
dbody = only(back(dy))
260+
return (NoTangent(), dbody, NoTangent(), NoTangent())
261+
end
262+
return y, task_local_storage_pullback
263+
end

src/rulesets/Base/indexing.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...)
113113
end
114114
function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...)
115115
view(dx, inds...) .+= dy
116-
# For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting.
117-
# Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131
118116
return dx
119117
end
120118

@@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...)
134132
return z, ∇getindex_pullback
135133
end
136134

135+
# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
136+
# To avoid this, copy everything back to the CPU.
137+
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:
138+
139+
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...)
140+
view(dx, inds...) .+= Ref(dy)
141+
return dx
142+
end
143+
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
144+
view(dx, inds...) .+= dy
145+
return dx
146+
end
147+
function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...)
148+
dx_cpu = adapt(Array, dx)
149+
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
150+
copyto!(dx, dx_cpu)
151+
return dx
152+
end
153+
137154
#####
138155
##### first, tail
139156
#####

test/rulesets/Base/indexing.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,25 @@
143143
test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false)
144144
test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false)
145145
end
146+
147+
@testset "GPU" begin
148+
x_23_gpu = jl(rand(2, 3))
149+
150+
# Scalar indexing, copied from: @macroexpand @allowscalar A[i]
151+
# Gives an error in Pkg.test, no idea why
152+
# y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed)
153+
# @test y1 == @allowscalar x_gpu[1]
154+
# bk1(1.0) # This is zero, because finite-differencing ignores the function
155+
# ... but this works, and calls the rule:
156+
# Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1]
157+
158+
y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+=
159+
@test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1])
160+
161+
y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU
162+
@test_skip Array(y3) == Array(x_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why
163+
@test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0])
164+
end
146165
end
147166

148167
@testset "first & tail" begin
@@ -178,6 +197,7 @@ end
178197
end
179198

180199
@testset "unsafe_getindex" begin
200+
# In real life this is called only on some AbstractRanges, but easier to test on Array:
181201
test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3)
182202
test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3)
183203
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils
22

33
@nospecialize
44

5+
using Adapt
56
using Base.Broadcast: broadcastable
67
using ChainRules
78
using ChainRulesCore
89
using ChainRulesTestUtils
910
using ChainRulesTestUtils: rand_tangent, _fdm
1011
using FiniteDifferences
12+
using GPUArraysCore
13+
using JLArrays
1114
using LinearAlgebra
1215
using LinearAlgebra.BLAS
1316
using LinearAlgebra: dot

0 commit comments

Comments
 (0)