|  | 
|  | 1 | +# These two `similar` methods can be removed once | 
|  | 2 | +# <https://github.com/JuliaArrays/StructArrays.jl/pull/94> | 
|  | 3 | +# is accepted. | 
# `similar` for a `StructArray` when the requested element type is the array's
# own element type `T`: every component array of `s` is passed to `similar` with
# the new size `sz`, and the results are wrapped in a `StructArray{T}`.
function Base.similar(s::StructArray{T,N,C}, ::Type{T}, sz::NTuple{M,Int64}) where {T,N,M,C<:Union{Tuple,NamedTuple}}
    resized = map(StructArrays.components(s)) do component
        similar(component, sz)
    end
    return StructArray{T}(resized)
end
|  | 8 | + | 
# `similar` for a `StructArray` when an arbitrary element type `S` is requested.
# Struct types get a new array built field-by-field via
# `StructArrays.buildfromschema`; any other type gets a plain array.
function Base.similar(s::StructArray{T,N,C}, S::Type, sz::NTuple{M,Int64}) where {T,N,M,C<:Union{Tuple,NamedTuple}}
    # Nothing tells us which array type should back each field of `S`, so the
    # first component of `s` is used, arbitrarily, as the template for all of
    # them. Callers that need something else have to be more specific.
    template = StructArrays.components(s)[1]
    isstructtype(S) || return similar(template, S, sz)
    return StructArrays.buildfromschema(S) do ftype
        similar(template, ftype, sz)
    end
end
|  | 20 | + | 
|  | 21 | +# The following broadcast code is slightly modified from the code found at | 
|  | 22 | +# <https://github.com/JuliaArrays/StructArrays.jl/issues/150>. | 
# Positional component storage: a tuple of any length whose entries are all
# `GPUArrays.BroadcastGPUArray`s.
const GPUStore = Tuple{Vararg{GPUArrays.BroadcastGPUArray}}
# Named component storage: a `NamedTuple` with any field names whose value
# tuple type is a `GPUStore`.
const NamedGPUStore = (NamedTuple{Name,Store} where {Store<:GPUStore}) where {Name}
# A `StructArray` (any element type `T`, any rank `N`) whose components are held
# in one of the two storage kinds above.
const StructGPUArray = (StructArray{T,N,Store} where {Store<:Union{GPUStore,NamedGPUStore}}) where {T,N}
|  | 26 | +## backend for StructArray | 
# Backend of a GPU-backed `StructArray`: delegate to its component storage.
function GPUArrays.backend(A::StructGPUArray)
    store = StructArrays.components(A)
    return GPUArrays.backend(store)
end

# Backend of a tuple of component arrays: resolved from the tuple's type.
function GPUArrays.backend(t::GPUStore)
    return GPUArrays.backend(typeof(t))
end

# Backend of a named tuple of component arrays: the second type parameter of
# the `NamedTuple` type is the tuple type of its values, so dispatch on that.
function GPUArrays.backend(nt::NamedGPUStore)
    store_type = typeof(nt).parameters[2]
    return GPUArrays.backend(store_type)
end
# Backend of a tuple *type* of component arrays.
#
# Looks up the backend of every element type of `T` and returns it when they all
# agree.
#
# Throws an `ArgumentError` when the element types map to different backends.
# (This used to `throw` a bare `String`, which is not an `Exception` and cannot
# be caught by type.)
function GPUArrays.backend(::Type{T}) where {T<:GPUStore}
    backends = map(GPUArrays.backend, tuple(T.parameters...))
    reference = first(backends)
    all(isequal(reference), backends) ||
        throw(ArgumentError("device error: components do not all share the same GPU backend"))
    # Every entry equals `reference`, so there is no need to look it up again.
    return reference
end
|  | 38 | + | 
|  | 39 | +## copy from GPUArrays | 
# Materialise the broadcast `bc` into the GPU-backed `StructArray` `dest` by
# launching a kernel through GPUArrays. Returns `dest`.
#
# Throws (via `Broadcast.throwdm`) when the axes of `dest` and `bc` differ.
@inline function Base.copyto!(dest::StructGPUArray,
                              bc::Broadcast.Broadcasted{Nothing})
    if axes(dest) != axes(bc)
        Broadcast.throwdm(axes(dest), axes(bc))
    end
    # Nothing to launch for an empty destination.
    if isempty(dest)
        return dest
    end
    prepared = Broadcast.preprocess(dest, bc)

    # Kernel body: loops over `1:nelem`, asks `GPUArrays.@cartesianidx` for the
    # index belonging to iteration `i`, and copies that one element.
    # NOTE(review): the first parameter is deliberately left named `ctx`;
    # presumably the index macro picks it up by name -- confirm before renaming.
    function broadcast_kernel(ctx, out, src, nelem)
        for i in 1:nelem
            idx = GPUArrays.@cartesianidx(out, i)
            @inbounds out[idx] = src[idx]
        end
        return
    end

    gpu_backend = GPUArrays.backend(dest)
    heuristic = GPUArrays.launch_heuristic(gpu_backend, broadcast_kernel,
                                           dest, prepared, 1)
    config = GPUArrays.launch_configuration(gpu_backend, heuristic,
                                            length(dest), typemax(Int))
    GPUArrays.gpu_call(broadcast_kernel, dest, prepared,
                       config.elements_per_thread;
                       threads=config.threads, blocks=config.blocks)

    return dest
end
|  | 64 | + | 
# Allocate the destination for a broadcast whose style is a `StructArrayStyle`
# wrapping a `CUDA.CuArrayStyle`. Struct element types get an array built
# field-by-field out of `CuArray`s; any other element type gets a single
# `CuArray{T}`.
function Base.similar(bc::Broadcast.Broadcasted{StructArrays.StructArrayStyle{S}},
                      ::Type{T}) where {S<:CUDA.CuArrayStyle,T}
    dims = axes(bc)
    isstructtype(T) || return similar(CuArray{T}, dims)
    return StructArrays.buildfromschema(T) do ftype
        similar(CuArray{ftype}, dims)
    end
end
|  | 74 | + | 
|  | 75 | +# We follow GPUArrays' approach of copying the whole array to the host when | 
|  | 76 | +# outputting a StructArray backed by GPU arrays. | 
# Adapt `xs` to `Array` storage so that it can be printed from the host.
function convert_to_cpu(xs)
    return adapt(Array, xs)
end
# Print a zero-dimensional `StructArray`: convert it with `convert_to_cpu`, then
# show its single element, or Base's "undefined reference" placeholder when the
# element is not assigned.
function Base.print_array(io::IO, X::StructArray{<:Any,0})
    # Use a separate local instead of rebinding `X` to a value of another type.
    host = convert_to_cpu(X)
    # `undef_ref_str` is an unexported Base constant, so it has to be qualified.
    return isassigned(host) ? show(io, host[]) : print(io, Base.undef_ref_str)
end
# Print a one-dimensional `StructArray` by converting it with `convert_to_cpu`
# and handing the result to `Base.print_matrix`.
function Base.print_array(io::IO, X::StructArray{<:Any,1})
    host = convert_to_cpu(X)
    return Base.print_matrix(io, host)
end
# Print a two-dimensional `StructArray` by converting it with `convert_to_cpu`
# and handing the result to `Base.print_matrix`.
# (The former `where {T}` clause declared a type variable that the signature
# never used; it has been dropped.)
Base.print_array(io::IO, X::StructArray{<:Any,2}) =
    Base.print_matrix(io, convert_to_cpu(X))
# Fallback for `StructArray`s of any other rank: convert with `convert_to_cpu`
# and print through `Base.show_nd`, using `Base.print_matrix` for each slice.
function Base.print_array(io::IO, X::StructArray{<:Any,<:Any})
    host = convert_to_cpu(X)
    return Base.show_nd(io, host, Base.print_matrix, true)
end
|  | 88 | + | 
|  | 89 | +# These definitions allow `StructArray` and `StaticArrays.SArray` to play nicely | 
|  | 90 | +# together. | 
# Schema of an `SArray` with `L` elements of type `T`: an `L`-tuple of `T`.
function StructArrays.staticschema(::Type{SArray{S,T,N,L}}) where {S,T,N,L}
    return NTuple{L,T}
end

# Build an `SArray{S,T,N,L}` from its individual entries.
function StructArrays.createinstance(::Type{SArray{S,T,N,L}}, args...) where {S,T,N,L}
    return SArray{S,T,N,L}(args...)
end

# The `i`-th component of an `SArray` is its `i`-th element.
function StructArrays.component(s::SArray, i)
    return s[i]
end
|  | 95 | + | 
# Kernel that stores `x` at the global index of the current work item in `A`.
@kernel function fill_kernel!(A, x)
    idx = @index(Global)
    @inbounds A[idx] = x
end
|  | 100 | + | 
# Fill every element of the `StructArray` `A` with `x` by launching
# `fill_kernel!` on the device of `A`, waiting for it to finish, and returning
# `A` as `Base.fill!` is documented to do.
function Base.fill!(A::StructArray, x)
    # Look the device up once; it is needed for both the event and the kernel.
    dev = device(A)
    event = Event(dev)
    # NOTE(review): 256 is the second argument used to instantiate the kernel,
    # presumably the workgroup size -- confirm against KernelAbstractions.
    event = fill_kernel!(dev, 256)(A, x, ndrange = length(A),
                                   dependencies = (event, ))
    wait(event)
    return A
end
|  | 107 | + | 
# Device shared by all components of the `StructArray` `s`.
#
# Calls `device` on every component and returns the common result.
#
# Throws an `ArgumentError` when the components do not all report the same
# device. (This used to `throw` a bare `String` and printed its intermediate
# values through leftover `@show` debugging on every call.)
function device(s::StructArray)
    devices = map(device, StructArrays.components(s))
    reference = first(devices)
    all(isequal(reference), devices) ||
        throw(ArgumentError("device error: components of the StructArray are not all on the same device"))
    return reference
end
0 commit comments