Skip to content

Commit 6f8ba49

Browse files
authored
Add methods to index identityUnitRange/Slice with another IdentityUnitRange (#41224)
Adding these methods lets `OffsetArrays` define `getindex(::AbstractUnitRange, ::IdentityUnitRange)` without ambiguities. This is in the domain of sanctioned type-piracy, as the result is an offset range in general and cannot be represented correctly using `Base` types. Re: JuliaArrays/OffsetArrays.jl#244 cc: @johnnychen94 Edit: this also fixes an indexing bug in `IdentityUntiRange`: master ```julia julia> r = Base.IdentityUnitRange(-3:3) Base.IdentityUnitRange(-3:3) julia> r[2] 2 julia> r[big(2)] -2 ``` Co-authored-by: jishnub <[email protected]>
1 parent 18a2e70 commit 6f8ba49

File tree

2 files changed

+86
-11
lines changed

2 files changed

+86
-11
lines changed

base/indices.jl

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,15 +423,57 @@ first(S::IdentityUnitRange) = first(S.indices)
423423
last(S::IdentityUnitRange) = last(S.indices)
424424
size(S::IdentityUnitRange) = (length(S.indices),)
425425
length(S::IdentityUnitRange) = length(S.indices)
426-
getindex(S::IdentityUnitRange, i::Int) = (@inline; @boundscheck checkbounds(S, i); i)
427-
getindex(S::IdentityUnitRange, i::AbstractUnitRange{<:Integer}) = (@inline; @boundscheck checkbounds(S, i); i)
428-
getindex(S::IdentityUnitRange, i::StepRange{<:Integer}) = (@inline; @boundscheck checkbounds(S, i); i)
426+
unsafe_length(S::IdentityUnitRange) = unsafe_length(S.indices)
427+
getindex(S::IdentityUnitRange, i::Integer) = (@inline; @boundscheck checkbounds(S, i); convert(eltype(S), i))
428+
getindex(S::IdentityUnitRange, i::Bool) = throw(ArgumentError("invalid index: $i of type Bool"))
429+
function getindex(S::IdentityUnitRange, i::AbstractUnitRange{<:Integer})
430+
@inline
431+
@boundscheck checkbounds(S, i)
432+
return convert(AbstractUnitRange{eltype(S)}, i)
433+
end
434+
function getindex(S::IdentityUnitRange, i::AbstractUnitRange{Bool})
435+
@inline
436+
@boundscheck checkbounds(S, i)
437+
range(first(i) ? first(S) : last(S), length = last(i))
438+
end
439+
function getindex(S::IdentityUnitRange, i::StepRange{<:Integer})
440+
@inline
441+
@boundscheck checkbounds(S, i)
442+
return convert(AbstractRange{eltype(S)}, i)
443+
end
444+
function getindex(S::IdentityUnitRange, i::StepRange{Bool})
445+
@inline
446+
@boundscheck checkbounds(S, i)
447+
range(first(i) ? first(S) : last(S), length = last(i), step = Int(step(i)))
448+
end
449+
# Indexing with offset ranges should preserve the axes of the indices
450+
# however, this is only really possible in general with OffsetArrays.
451+
# In some cases, though, we may obtain correct results using Base ranges
452+
# the following methods are added to allow OffsetArrays to dispatch on the first argument without ambiguities
453+
function getindex(S::IdentityUnitRange{<:AbstractUnitRange{<:Integer}},
454+
i::IdentityUnitRange{<:AbstractUnitRange{<:Integer}})
455+
@inline
456+
@boundscheck checkbounds(S, i)
457+
return i
458+
end
459+
function getindex(S::Slice{<:AbstractUnitRange{<:Integer}},
460+
i::IdentityUnitRange{<:AbstractUnitRange{<:Integer}})
461+
@inline
462+
@boundscheck checkbounds(S, i)
463+
return i
464+
end
429465
show(io::IO, r::IdentityUnitRange) = print(io, "Base.IdentityUnitRange(", r.indices, ")")
430466
iterate(S::IdentityUnitRange, s...) = iterate(S.indices, s...)
431467

432468
# For OneTo, the values and indices of the values are identical, so this may be defined in Base.
433469
# In general such an indexing operation would produce offset ranges
434-
getindex(S::OneTo, I::IdentityUnitRange{<:AbstractUnitRange{<:Integer}}) = (@inline; @boundscheck checkbounds(S, I); I)
470+
# This should also ideally return an AbstractUnitRange{eltype(S)}, but currently
471+
# we're restricted to eltype(::IdentityUnitRange) == Int by definition
472+
function getindex(S::OneTo, I::IdentityUnitRange{<:AbstractUnitRange{<:Integer}})
473+
@inline
474+
@boundscheck checkbounds(S, I)
475+
return I
476+
end
435477

436478
"""
437479
LinearIndices(A::AbstractArray)

test/ranges.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,13 +2376,46 @@ end
23762376
@test 0.2 * (-2:2:2) == [-0.4, 0, 0.4]
23772377
end
23782378

2379-
@testset "Indexing OneTo with IdentityUnitRange" begin
2380-
for endpt in Any[10, big(10), UInt(10)]
2381-
r = Base.OneTo(endpt)
2382-
inds = Base.IdentityUnitRange(3:5)
2383-
rs = r[inds]
2384-
@test rs === inds
2385-
@test_throws BoundsError r[Base.IdentityUnitRange(-1:100)]
2379+
@testset "IdentityUnitRange indexing" begin
2380+
@testset "Indexing into an IdentityUnitRange" begin
2381+
@testset for r in Any[-1:20, Base.OneTo(20)]
2382+
ri = Base.IdentityUnitRange(r)
2383+
@test_throws "invalid index" ri[true]
2384+
@testset for s in Any[Base.OneTo(6), Base.OneTo{BigInt}(6), 3:6, big(3):big(6), 3:2:7]
2385+
@test mapreduce(==, &, ri[s], ri[s[begin]]:step(s):ri[s[end]])
2386+
@test axes(ri[s]) == axes(s)
2387+
@test eltype(ri[s]) == eltype(ri)
2388+
end
2389+
end
2390+
@testset "Bool indices" begin
2391+
r = 1:1
2392+
@test Base.IdentityUnitRange(r)[true:true] == r[true:true]
2393+
@test Base.IdentityUnitRange(r)[true:true:true] == r[true:true:true]
2394+
@test_throws BoundsError Base.IdentityUnitRange(1:2)[true:true]
2395+
@test_throws BoundsError Base.IdentityUnitRange(1:2)[true:true:true]
2396+
end
2397+
end
2398+
@testset "Indexing with IdentityUnitRange" begin
2399+
@testset "OneTo" begin
2400+
@testset for endpt in Any[10, big(12), UInt(11)]
2401+
r = Base.OneTo(endpt)
2402+
inds = Base.IdentityUnitRange(3:5)
2403+
rs = r[inds]
2404+
@test rs == inds
2405+
@test axes(rs) == axes(inds)
2406+
@test_throws BoundsError r[Base.IdentityUnitRange(-1:100)]
2407+
end
2408+
end
2409+
@testset "IdentityUnitRange" begin
2410+
@testset for r in Any[Base.IdentityUnitRange(1:4), Base.IdentityUnitRange(Base.OneTo(4)), Base.Slice(1:4), Base.Slice(Base.OneTo(4))]
2411+
@testset for s in Any[Base.IdentityUnitRange(3:3), Base.IdentityUnitRange(Base.OneTo(2)), Base.Slice(3:3), Base.Slice(Base.OneTo(2))]
2412+
rs = r[s]
2413+
@test rs == s
2414+
@test axes(rs) == axes(s)
2415+
end
2416+
@test_throws BoundsError r[Base.IdentityUnitRange(first(r):last(r) + 1)]
2417+
end
2418+
end
23862419
end
23872420
end
23882421

0 commit comments

Comments
 (0)