Skip to content

Commit 4334830

Browse files
committed
Makegetindex works the same as Base
1 parent 57ef2c7 commit 4334830

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

src/indexing.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ to_index(::MyIndexStyle, axis, arg) = ...
183183
"""
184184
to_index(x, i::Slice) = i
185185
to_index(x, ::Colon) = indices(x)
186+
to_index(::LinearIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
187+
to_index(::CartesianIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
186188
# logical indexing
187189
to_index(x, i::AbstractArray{Bool}) = LogicalIndex(i)
188190
to_index(x::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
@@ -251,7 +253,7 @@ indices calling [`to_axis`](@ref).
251253
end
252254
end
253255
# drop this dimension
254-
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
256+
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
255257
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
256258
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
257259
return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
@@ -354,7 +356,9 @@ unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
354356
end
355357

356358
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
357-
unsafe_getindex(A::CartesianIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) = CartesianIndex(i, ii...)
359+
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...)
360+
unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) =
361+
unsafe_getindex(A, Base.front(ii)...)
358362
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i])
359363

360364
unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i])
@@ -381,20 +385,16 @@ function unsafe_get_collection(A, inds)
381385
end
382386
return dest
383387
end
384-
_ints2range(x::CanonicalInt) = x:x
385-
_ints2range(x::AbstractRange) = x
388+
# _ints2range(x::CanonicalInt) = x:x
389+
# _ints2range(x::AbstractRange) = x
386390
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
387-
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
388-
return Base._getindex(IndexStyle(A), A, inds...)
389-
else
390-
return CartesianIndices(to_axes(A, _ints2range.(inds)))
391-
end
391+
return Base._getindex(IndexStyle(A), A, inds...)
392392
end
393393
@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N}
394394
if Base.length(inds) === 1 && isone(_ndims_index(typeof(inds), static(1)))
395395
return @inbounds(eachindex(A)[first(inds)])
396396
elseif stride_preserving_index(typeof(inds)) === True()
397-
return LinearIndices(to_axes(A, _ints2range.(inds)))
397+
return LinearIndices(to_axes(A, inds))
398398
else
399399
return Base._getindex(IndexStyle(A), A, inds...)
400400
end

test/indexing.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,34 @@ end
103103
@testset "getindex with additional inds" begin
104104
A = reshape(1:12, (3, 4))
105105
subA = view(A, :, :)
106+
LA = LinearIndices(A)
107+
CA = CartesianIndices(A)
106108
@test @inferred(ArrayInterface.getindex(A, 1, 1, 1)) == 1
107109
@test @inferred(ArrayInterface.getindex(A, 1, 1, :)) == [1]
110+
@test @inferred(ArrayInterface.getindex(A, 1, 1, 1:1)) == [1]
108111
@test @inferred(ArrayInterface.getindex(A, 1, 1, :, :)) == ones(1, 1)
112+
@test @inferred(ArrayInterface.getindex(A, :, 1, 1)) == 1:3
113+
@test @inferred(ArrayInterface.getindex(A, :, 1, :)) == reshape(1:3, 3, 1)
109114
@test @inferred(ArrayInterface.getindex(subA, 1, 1, 1)) == 1
110115
@test @inferred(ArrayInterface.getindex(subA, 1, 1, :)) == [1]
116+
@test @inferred(ArrayInterface.getindex(subA, 1, 1, 1:1)) == [1]
111117
@test @inferred(ArrayInterface.getindex(subA, 1, 1, :, :)) == ones(1, 1)
118+
@test @inferred(ArrayInterface.getindex(subA, :, 1, 1)) == 1:3
119+
@test @inferred(ArrayInterface.getindex(subA, :, 1, :)) == reshape(1:3, 3, 1)
120+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, 1)) == 1
121+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, :)) == [1]
122+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, 1:1)) == [1]
123+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, :, :)) == ones(1, 1)
124+
@test @inferred(ArrayInterface.getindex(LA, :, 1, 1)) == 1:3
125+
@test @inferred(ArrayInterface.getindex(LA, :, 1, :)) == reshape(1:3, 3, 1)
126+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, 1)) == CartesianIndex(1, 1)
127+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, :)) == [CartesianIndex(1, 1)]
128+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, 1:1)) == [CartesianIndex(1, 1)]
129+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, :, :)) == fill(CartesianIndex(1, 1), 1, 1)
130+
@test @inferred(ArrayInterface.getindex(CA, :, 1, 1)) ==
131+
reshape(CartesianIndex(1, 1):CartesianIndex(3, 1), 3)
132+
@test @inferred(ArrayInterface.getindex(CA, :, 1, :)) ==
133+
reshape(CartesianIndex(1, 1):CartesianIndex(3, 1), 3, 1)
112134
end
113135

114136
@testset "0-dimensional" begin

0 commit comments

Comments
 (0)