Skip to content
This repository was archived by the owner on Sep 20, 2024. It is now read-only.

Commit 9b097ae

Browse files
committed
Support StructArray+SArray on the GPU
This adds the necessary functionality to run a `StructArray` with elements of type `SArray`. For this to work on the GPU we need to inline a couple of functions in StructArrays, see <JuliaArrays/StructArrays.jl#177>.
1 parent a258789 commit 9b097ae

File tree

6 files changed

+224
-0
lines changed

6 files changed

+224
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1010
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
12+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1213
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1314
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
19+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1820
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1921
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"
2022

@@ -28,6 +30,7 @@ KernelAbstractions = "0.6"
2830
LazyArrays = "0.20, 0.21"
2931
LoopVectorization = "0.11, 0.12"
3032
StaticArrays = "0.12, 1"
33+
StructArrays = "0.5"
3134
Tullio = "0.2"
3235
WriteVTK = "1.9"
3336
julia = "1.6"

src/Bennu.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ using ArrayInterface
55
using CUDA
66
using CUDAKernels
77
using FillArrays
8+
using GPUArrays
89
using KernelAbstractions
910
using LazyArrays
1011
using LinearAlgebra
1112
using LoopVectorization
1213
using SparseArrays
1314
using StaticArrays
1415
using StaticArrays: tuple_prod, tuple_length, size_to_tuple
16+
using StructArrays
1517
using Tullio
1618
using WriteVTK
1719

@@ -37,6 +39,7 @@ include("operators.jl")
3739
include("partitions.jl")
3840
include("quadratures.jl")
3941
include("sparsearrays.jl")
42+
include("structarrays.jl")
4043
include("tuples.jl")
4144

4245
end

src/structarrays.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# These two similar functions can be removed once
2+
# <https://github.com/JuliaArrays/StructArrays.jl/pull/94>
3+
# is accepted.
4+
# Allocate a new `StructArray` with the same element type `T` as `s` and with
# dimensions `sz`. Every component array of `s` is asked for a `similar` array
# of the requested size, and the results are wrapped back into a `StructArray`.
function Base.similar(s::StructArray{T,N,C}, ::Type{T}, sz::NTuple{M,Int64}) where {T,N,M,C<:Union{Tuple,NamedTuple}}
    columns = map(StructArrays.components(s)) do column
        similar(column, sz)
    end
    return StructArray{T}(columns)
end
8+
9+
# Allocate storage with element type `S` and dimensions `sz`, using `s` as the
# template. Struct element types are built through
# `StructArrays.buildfromschema`; any other element type gets a plain array
# that is `similar` to the first component of `s`.
function Base.similar(s::StructArray{T,N,C}, S::Type, sz::NTuple{M,Int64}) where {T,N,M,C<:Union{Tuple,NamedTuple}}
    # The caller did not say which kind of array should back each interior
    # type, so the first component of `s` is picked arbitrarily as the template
    # for all of them. Callers that need something else must be more specific.
    template = StructArrays.components(s)[1]
    isstructtype(S) || return similar(template, S, sz)
    return StructArrays.buildfromschema(S) do fieldtype
        similar(template, fieldtype, sz)
    end
end
20+
21+
# The following broadcast code is slightly modified from the code found at
22+
# <https://github.com/JuliaArrays/StructArrays.jl/issues/150>.
23+
# Type aliases describing a `StructArray` whose components are all GPU arrays:
#   * `GPUStore`       - tuple of `GPUArrays.BroadcastGPUArray` components,
#   * `NamedGPUStore`  - named tuple whose values form a `GPUStore`,
#   * `StructGPUArray` - `StructArray` stored in either of the above.
const GPUStore = Tuple{Vararg{GPUArrays.BroadcastGPUArray}}
const NamedGPUStore{Name} = NamedTuple{Name,<:GPUStore}
const StructGPUArray{T,N} = StructArray{T,N,<:Union{GPUStore,NamedGPUStore}}
26+
## backend for StructArray
27+
# Backend of a GPU-backed `StructArray`: defer to its component storage.
function GPUArrays.backend(A::StructGPUArray)
    return GPUArrays.backend(StructArrays.components(A))
end

# Backend of a tuple of GPU arrays: defer to the method taking the tuple type.
function GPUArrays.backend(t::GPUStore)
    return GPUArrays.backend(typeof(t))
end

# Backend of a named tuple of GPU arrays: the second type parameter of the
# `NamedTuple` type is the tuple type of its values.
function GPUArrays.backend(nt::NamedGPUStore)
    return GPUArrays.backend(typeof(nt).parameters[2])
end
32+
# Backend shared by every array type in the tuple type `T`.
#
# The backend of each tuple parameter is looked up and all of them are required
# to be `isequal` to the first one; that common backend is returned.
#
# Throws an `ArgumentError` when `T` has no parameters or when the component
# backends differ. (A bare `String` used to be thrown for the mismatch case,
# which is not an `Exception` and cannot be caught by type.)
function GPUArrays.backend(::Type{T}) where {T<:GPUStore}
    bs = map(GPUArrays.backend, tuple(T.parameters...))
    isempty(bs) &&
        throw(ArgumentError("device error: cannot determine the backend of an empty store"))
    # `all` with a predicate short-circuits and avoids the temporary tuple that
    # `all(map(...))` would build.
    all(isequal(first(bs)), bs) ||
        throw(ArgumentError("device error: components use different GPUArrays backends"))
    return first(bs)
end
38+
39+
## copy from GPUArrays
40+
# Materialize the broadcast expression `bc` into the GPU-backed `StructArray`
# `dest` and return `dest`. Adapted from the GPUArrays implementation.
#
# The axes of `dest` and `bc` must agree (otherwise `Broadcast.throwdm` is
# called); an empty `dest` is returned untouched without launching anything.
@inline function Base.copyto!(dest::StructGPUArray,
                              bc::Broadcast.Broadcasted{Nothing})
    dest_axes = axes(dest)
    bc_axes = axes(bc)
    dest_axes == bc_axes || Broadcast.throwdm(dest_axes, bc_axes)
    isempty(dest) && return dest

    prepared = Broadcast.preprocess(dest, bc)
    dev_backend = GPUArrays.backend(dest)

    # Kernel body: loops `per_thread` times, turning the loop counter into a
    # Cartesian index of `out` via `GPUArrays.@cartesianidx` and copying that
    # entry of the broadcast into `out`.
    # NOTE(review): the first argument is deliberately still called `ctx`; the
    # GPUArrays indexing macro presumably picks up the kernel context by that
    # name - confirm against the GPUArrays version in use before renaming.
    function broadcast_kernel(ctx, out, src, per_thread)
        for k in 1:per_thread
            idx = GPUArrays.@cartesianidx(out, k)
            @inbounds out[idx] = src[idx]
        end
        return
    end

    heuristic = GPUArrays.launch_heuristic(dev_backend, broadcast_kernel,
                                           dest, prepared, 1)
    config = GPUArrays.launch_configuration(dev_backend, heuristic,
                                            length(dest), typemax(Int))
    GPUArrays.gpu_call(broadcast_kernel, dest, prepared,
                       config.elements_per_thread;
                       threads=config.threads, blocks=config.blocks)

    return dest
end
64+
65+
# Allocate the destination for a broadcast whose style is a `StructArrayStyle`
# wrapping a `CUDA.CuArrayStyle`. Struct element types get a `StructArray`
# built from one `CuArray` per schema entry; any other element type gets a
# single `CuArray{T}`. All arrays are sized from `axes(bc)`.
function Base.similar(bc::Broadcast.Broadcasted{StructArrays.StructArrayStyle{S}},
                      ::Type{T}) where {S<:CUDA.CuArrayStyle,T}
    # One allocator shared by both branches.
    newcolumn(E) = similar(CuArray{E}, axes(bc))
    return isstructtype(T) ? StructArrays.buildfromschema(newcolumn, T) : newcolumn(T)
end
74+
75+
# We follow the GPUArrays approach of copying the whole array to the host when
76+
# outputting a StructArray backed by GPU arrays.
77+
# Convert `xs` to host (`Array`) storage by calling `adapt(Array, xs)`.
function convert_to_cpu(xs)
    return adapt(Array, xs)
end
78+
# Print a zero-dimensional `StructArray` to `io`.
#
# The array is first converted to host storage with `convert_to_cpu`. Its
# single element is shown if it is assigned; otherwise Base's placeholder
# string for undefined references is printed.
function Base.print_array(io::IO, X::StructArray{<:Any,0})
    # Use a fresh local instead of rebinding `X` to a value of another type.
    host = convert_to_cpu(X)
    # `undef_ref_str` is not exported from Base, so it has to be qualified;
    # the unqualified name is not defined in this module.
    return isassigned(host) ? show(io, host[]) : print(io, Base.undef_ref_str)
end
82+
# Printing of one-, two- and higher-dimensional `StructArray`s. In every case
# the array is converted to host storage with `convert_to_cpu` before being
# handed to Base's printing routines.

# Vectors are printed through `Base.print_matrix`.
Base.print_array(io::IO, X::StructArray{<:Any,1}) =
    Base.print_matrix(io, convert_to_cpu(X))

# Matrices are printed through `Base.print_matrix`. (The former `where {T}`
# clause was dropped: `T` was never used and only produced a method-definition
# warning.)
Base.print_array(io::IO, X::StructArray{<:Any,2}) =
    Base.print_matrix(io, convert_to_cpu(X))

# Every remaining dimensionality goes through `Base.show_nd`.
Base.print_array(io::IO, X::StructArray{<:Any,<:Any}) =
    Base.show_nd(io, convert_to_cpu(X), Base.print_matrix, true)
88+
89+
# These definitions allow `StructArray` and `StaticArrays.SArray` to play nicely
90+
# together.
91+
# Schema of an `SArray` with `L` entries of type `T`: an `NTuple{L,T}`, i.e.
# one component per entry.
function StructArrays.staticschema(::Type{SArray{S,T,N,L}}) where {S,T,N,L}
    return NTuple{L,T}
end

# Rebuild an `SArray` from its individual entries.
function StructArrays.createinstance(::Type{SArray{S,T,N,L}}, args...) where {S,T,N,L}
    return SArray{S,T,N,L}(args...)
end

# The `i`-th component of an `SArray` is its `i`-th entry.
function StructArrays.component(s::SArray, i)
    return s[i]
end
95+
96+
# KernelAbstractions kernel: store `x` at the global index of each work item.
@kernel function fill_kernel!(A, x)
    idx = @index(Global)
    @inbounds A[idx] = x
end
100+
101+
# Set every element of the `StructArray` `A` to `x` by launching `fill_kernel!`
# on the device reported by `device(A)`, wait for the kernel to finish, and
# return `A` (as `Base.fill!` is documented to do; previously the result of
# `wait` was returned).
function Base.fill!(A::StructArray, x)
    # Nothing to write for an empty array; skip the kernel launch entirely.
    isempty(A) && return A

    # Look the device up once and reuse it for the event and the launch.
    dev = device(A)
    event = Event(dev)
    event = fill_kernel!(dev, 256)(A, x; ndrange = length(A),
                                   dependencies = (event,))
    wait(event)
    return A
end
107+
108+
# Device shared by all components of the `StructArray` `s`.
#
# `device` is applied to every component and all results are required to be
# `isequal` to the first one, which is then returned.
#
# Throws an `ArgumentError` when the components report different devices.
# (Leftover `@show` debugging output was removed, and the bare `String` that
# used to be thrown was replaced by a proper exception.)
function device(s::StructArray)
    ds = map(device, StructArrays.components(s))
    all(isequal(first(ds)), ds) ||
        throw(ArgumentError("device error: StructArray components live on different devices"))
    return first(ds)
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
16+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1819
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using LinearAlgebra
1515
using Random
1616
using SparseArrays
1717
using StaticArrays
18+
using StructArrays
1819
using Tullio
1920
using WriteVTK
2021

@@ -31,6 +32,7 @@ include("partitions.jl")
3132
include("permutations.jl")
3233
include("quadratures.jl")
3334
include("sparsearrays.jl")
35+
include("structarrays.jl")
3436
include("tuples.jl")
3537

3638
@testset "examples" begin

test/structarrays.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Test kernel: write `B[i] + B[i]` into `A[i]` at each work item's global index.
@kernel function dup_kernel!(A, @Const(B))
    idx = @index(Global)
    @inbounds begin
        A[idx] = B[idx] + B[idx]
    end
end
5+
6+
# Test kernel: copy the `a` field of `AB[i]` into `A[i]` and the `b` field into
# `B[i]`. `AB` is read once per field, in that order.
@kernel function split_kernel!(A, B, @Const(AB))
    idx = @index(Global)
    @inbounds begin
        A[idx] = AB[idx].a
        B[idx] = AB[idx].b
    end
    nothing
end
12+
13+
# Test kernel: store the named tuple `(a = A[i], b = B[i])` into `AB[i]`.
@kernel function join_kernel!(AB, @Const(A), @Const(B))
    idx = @index(Global)
    @inbounds begin
        AB[idx] = (a=A[idx], b=B[idx])
    end
end
17+
18+
# Exercises `StructArray`s whose elements are static arrays: construction from
# views, `similar`, `fill!`, broadcasting and user kernels. Runs on the host
# for `Float64` and `BigFloat`, and additionally with `CuArray` storage when a
# CUDA GPU is available.
@testset "StructArrays" begin
    TAs = ((Float64, Array), (BigFloat, Array))
    if CUDA.has_cuda_gpu()
        TAs = (TAs..., (Float32, CuArray))
    end

    for (T, A) in TAs
        N = 4
        C = 3
        M = 5

        # N x C x M array holding 1, 2, ..., N*C*M, stored as array type `A`.
        X = adapt(A, collect(reshape(oneunit(T):N*C*M, N, C, M)))

        # One view per static-matrix entry; `b` is backed by those views.
        tup = ntuple(i->view(X, :, i, :), C)
        b = StructArray{SMatrix{C, 1, T, C}}(tup)
        @test b isa StructArray
        @test Bennu.device(b) == Bennu.device(A)

        # Host array-of-structs reference copy of `b`.
        aos_b = collect(adapt(Array, b))
        @test aos_b isa Array
        @test Bennu.device(aos_b) == Bennu.device(Array)

        # `similar` and `fill!`.
        a = similar(b)
        fill!(a, zero(eltype(a)))
        @test all(iszero.(a))
        @test a isa StructArray
        @test Bennu.device(a) == Bennu.device(A)

        # In-place broadcast.
        a .= b .+ b
        @test isapprox(collect(adapt(Array, a)), aos_b .+ aos_b)

        # StructArray of StructArrays; `ab.a` and `ab.b` refer to `a` and `b`.
        ab = StructArray((a=a, b=b))
        @test ab isa StructArray
        @test Bennu.device(ab) == Bennu.device(A)
        @test all(ab.a .== a)
        @test all(ab.b .== b)

        # `similar` with a non-struct element type gives a plain array.
        ab2 = similar(ab, Int)
        @test eltype(ab2) == Int
        @test ab2 isa A
        @test Bennu.device(ab2) == Bennu.device(A)

        # Allocating broadcast with a struct result.
        e = b .* 3
        @test isapprox(collect(adapt(Array, e)), aos_b .* 3)
        @test e isa StructArray
        @test Bennu.device(e) == Bennu.device(A)

        # Allocating broadcast with a scalar result.
        c = norm.(b)
        @test isapprox(collect(adapt(Array, c)), norm.(aos_b))
        @test c isa A

        # `similar` with a different struct element type and various sizes.
        for s in ((), (1,), (1,3), (1,3,2))
            d = similar(b, SVector{2, Float32}, s)
            @test size(d) == s
            @test d isa StructArray
            @test Bennu.device(d) == Bennu.device(A)
        end

        # From here on `a` is a fresh array; `ab.a` still refers to the old one.
        a = similar(b)
        fill!(a, zero(eltype(a)))
        event = Event(Bennu.device(A))
        event = dup_kernel!(Bennu.device(A), 256)(a, b, ndrange=length(a),
                                                  dependencies = (event, ))
        wait(event)
        # `dup_kernel!` writes `B[I] + B[I]`, so the expected result is twice
        # `b` (the former comparison against `b` itself could not hold for the
        # non-zero entries of `b`).
        @test all(a .== b .+ b)

        fill!(a, zero(eltype(a)))
        fill!(b, zero(eltype(b)))
        event = Event(Bennu.device(A))
        event = split_kernel!(Bennu.device(A), 256)(a, b, ab, ndrange=length(a),
                                                    dependencies = (event, ))
        wait(event)
        @test all(ab.a .== a)
        @test all(ab.b .== b)

        fill!(ab, (a=zero(eltype(a)), b=zero(eltype(b))))
        event = Event(Bennu.device(A))
        event = join_kernel!(Bennu.device(A), 256)(ab, a, b, ndrange=length(a),
                                                   dependencies = (event, ))
        wait(event)
        @test all(ab.a .== a)
        @test all(ab.b .== b)
    end
end

0 commit comments

Comments
 (0)