diff --git a/src/gather.jl b/src/gather.jl index d75f89a2..d84015de 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -78,6 +78,7 @@ end to_cartesian_index(IJK...) = CartesianIndex.(IJK...) @non_differentiable to_cartesian_index(::Any...) + """ NNlib.gather!(dst, src, idx) @@ -109,6 +110,41 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) return dst end +""" + gather!(dst, src, IJK...) + +Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and +call `gather!` on it: `gather!(dst, src, CartesianIndex.(IJK...))`. + +# Examples + +```jldoctest +julia> src = reshape([1:15;], 3, 5) +3×5 Matrix{Int64}: + 1 4 7 10 13 + 2 5 8 11 14 + 3 6 9 12 15 + +julia> dst = similar(src, 2) +2-element Vector{Int64}: + 0 + 0 + +julia> NNlib.gather!(dst, src, [1, 2], [2, 4]) +2-element Vector{Int64}: + 4 + 11 +``` +""" +function gather!( + dst::AbstractArray{Tdst, Ndst}, + src::AbstractArray{Tsrc, Nsrc}, + I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, + Ks::AbstractVector{<:Integer}..., +) where {Ndst, Tdst, Nsrc, Tsrc} + return gather!(dst, src, to_cartesian_index(I, J, Ks...)) +end + function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] diff --git a/test/testsuite/gather.jl b/test/testsuite/gather.jl index 92e3bfb7..f336b4c4 100644 --- a/test/testsuite/gather.jl +++ b/test/testsuite/gather.jl @@ -200,5 +200,21 @@ function gather_testsuite(Backend) 2 5 3 6] end + + @testset "gather!(dst, src, IJK...)" begin + x = device(reshape([1:15;], 3, 5)) + i, j = device([1,2]), device([2,4]) + dst = device(zeros(Int, 2)) + gather!(dst, x, i, j) + @test cpu(dst) == [4, 11] + + # Test with the issue example + s = device([1, 2, 3]) + t = device([2, 3, 1]) + A = device([0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0]) + w = device([0.0, 0.0, 0.0]) + gather!(w, A, s, t) + @test cpu(w) == [1.0, 1.0, 1.0] + end end