|  | 
|  | 1 | +## Base interface | 
|  | 2 | + | 
|  | 3 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) = | 
|  | 4 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) | 
|  | 5 | + | 
|  | 6 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) = | 
|  | 7 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) | 
|  | 8 | + | 
|  | 9 | +Base._accumulate!(op, output::AnyGPUArray, input::MtlVector, dims::Nothing, init::Some) = | 
|  | 10 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) | 
|  | 11 | + | 
|  | 12 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) = | 
|  | 13 | +    AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) | 
|  | 14 | + | 
|  | 15 | +Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v) | 
|  | 16 | + | 
|  | 17 | +# default behavior unless dims are specified by the user | 
|  | 18 | +function Base.accumulate(op, A::WrappedGPUArray; | 
|  | 19 | +                         dims::Union{Nothing,Integer}=nothing, kw...) | 
|  | 20 | +    nt = values(kw) | 
|  | 21 | +    if dims === nothing && !(A isa AbstractVector) | 
|  | 22 | +        # This branch takes care of the cases not handled by `_accumulate!`. | 
|  | 23 | +        return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A)) | 
|  | 24 | +    end | 
|  | 25 | +    if isempty(kw) | 
|  | 26 | +        out = similar(A, Base.promote_op(op, eltype(A), eltype(A))) | 
|  | 27 | +        init = AK.neutral_element(op, eltype(out)) | 
|  | 28 | +    elseif keys(nt) === (:init,) | 
|  | 29 | +        out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A))) | 
|  | 30 | +        init = nt.init | 
|  | 31 | +    else | 
|  | 32 | +        throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) | 
|  | 33 | +    end | 
|  | 34 | +    AK.accumulate!(op, out, A, get_backend(A); dims, init) | 
|  | 35 | +end | 
0 commit comments