Skip to content

Commit 8d33937

Browse files
committed
Integrte dimension into IndicesInfo
1 parent dadccd9 commit 8d33937

File tree

6 files changed

+64
-62
lines changed

6 files changed

+64
-62
lines changed

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -671,51 +671,55 @@ ndims_shape(x) = ndims_shape(typeof(x))
671671
end
672672

673673
"""
674-
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS}()
674+
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,NI,NS}()
675675
676676
Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
677677
`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
678678
[`is_splat_index`](@ref) (respectively) for each field of `T`.
679+
680+
# Examples
681+
682+
```julia
683+
julia> using ArrayInterfaceCore: IndicesInfo
684+
685+
julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
686+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
687+
688+
```
679689
"""
680-
struct IndicesInfo{NI,NS,IS} end
681-
IndicesInfo(x::Union{Tuple,SubArray}) = IndicesInfo(typeof(x))
682-
function IndicesInfo(@nospecialize T::Type{<:Tuple})
683-
IndicesInfo{
684-
map_tuple_type(ndims_index, T),
685-
map_tuple_type(ndims_shape, T),
686-
_find_first_true(map_tuple_type(is_splat_index, T))
687-
}()
690+
struct IndicesInfo{N,NI,NS} end
691+
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
692+
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
693+
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
688694
end
689-
function IndicesInfo(@nospecialize T::Type{<:SubArray})
690-
indices_to_dims(IndicesInfo(fieldtype(T, :indices)), parent_type(T))
695+
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
696+
_indices_info(
697+
Val{_find_first_true(map_tuple_type(is_splat_index, T))}(),
698+
IndicesInfo{N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)}()
699+
)
691700
end
692-
693-
@inline function indices_to_dims(@nospecialize(I::Type{<:Tuple}), @nospecialize(T::Type))
694-
indices_to_dims(IndicesInfo(I), T)
701+
function _indices_info(::Val{nothing}, ::IndicesInfo{1,(1,),NS}) where {NS}
702+
ns1 = getfield(NS, 1)
703+
IndicesInfo{1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
695704
end
696-
@inline function indices_to_dims(::IndicesInfo{(1,),NS,nothing}, @nospecialize(T::Type)) where {NS}
705+
function _indices_info(::Val{nothing}, ::IndicesInfo{N,(1,),NS}) where {N,NS}
697706
ns1 = getfield(NS, 1)
698-
nsd = ns1 > 1 ? ntuple(identity, ns1) : ns1
699-
if ndims(T) === 1
700-
IndicesInfo{(1,), (nsd,), nothing}()
701-
else
702-
IndicesInfo{(:,), (nsd,), nothing}()
703-
end
707+
IndicesInfo{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
704708
end
705-
@inline function indices_to_dims(::IndicesInfo{NI,NS,nothing}, @nospecialize(T::Type)) where {NI,NS}
706-
if sum(NI) > ndims(T)
707-
IndicesInfo{_replace_trailing(ndims(T), _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS), nothing}()
709+
@inline function _indices_info(::Val{nothing}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS}
710+
if sum(NI) > N
711+
IndicesInfo{N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)}()
708712
else
709-
IndicesInfo{_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS), nothing}()
713+
IndicesInfo{N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)}()
710714
end
711715
end
712-
@inline function indices_to_dims(::IndicesInfo{NI,NS,SI}, @nospecialize(T::Type)) where {NI,NS,SI}
713-
nsplat = ndims(T) - sum(NI)
716+
@inline function _indices_info(::Val{SI}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS,SI}
717+
nsplat = N - sum(NI)
714718
if nsplat === 0
715-
indices_to_dims(IndicesInfo{NI,NS,nothing}(), T)
719+
_indices_info(Val{nothing}(), IndicesInfo{N,NI,NS}())
716720
else
717721
splatmul = max(0, nsplat + 1)
718-
indices_to_dims(IndicesInfo{_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS),nothing}(), T)
722+
_indices_info(Val{nothing}(), IndicesInfo{N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)}())
719723
end
720724
end
721725
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -294,35 +294,35 @@ end
294294
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(rand(4,4)', StepRangeLen(1,0,5), 1:2)))
295295
end
296296

297-
@testset "indices_to_dims" begin
297+
@testset "IndicesInfo" begin
298298

299299
struct SplatFirst end
300300

301301
ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true
302302

303303
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
304-
IndicesInfo{(1,),((1,2),),nothing}()
304+
IndicesInfo{1,(1,),((1,2),)}()
305305

306-
@test @inferred(indices_to_dims(Tuple{Vector{Int}}, Vector{Int})) == IndicesInfo{(1,), (1,), nothing}()
306+
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) == IndicesInfo{1, (1,), (1,)}()
307307

308-
@test @inferred(indices_to_dims(Tuple{Vector{Int}}, Matrix{Int})) == IndicesInfo{(:,), (1,), nothing}()
308+
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) == IndicesInfo{2, (:,), (1,)}()
309309

310-
@test @inferred(indices_to_dims(Tuple{SplatFirst}, Vector{Int})) == IndicesInfo{(1,), (1,), nothing}()
310+
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) == IndicesInfo{1, (1,), (1,)}()
311311

312-
@test @inferred(indices_to_dims(Tuple{SplatFirst}, Matrix{Int})) == IndicesInfo{((1,2),), ((1, 2),), nothing}()
312+
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) == IndicesInfo{2, ((1,2),), ((1, 2),)}()
313313

314-
@test @inferred(indices_to_dims(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)), Array{Int,5})) ==
315-
IndicesInfo{(1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0), nothing}()
314+
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) ==
315+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
316316

317-
@test @inferred(indices_to_dims(Tuple{Vararg{Int,10}}, Array{Int,10})) ==
318-
IndicesInfo{(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0), nothing}()
317+
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) ==
318+
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}()
319319

320-
@test @inferred(indices_to_dims(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)),Array{Int,10})) ==
321-
IndicesInfo{(1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0), nothing}()
320+
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) ==
321+
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}()
322322

323-
@test @inferred(indices_to_dims(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)), Array{Int,10})) ==
324-
IndicesInfo{((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0), nothing}()
323+
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) ==
324+
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}()
325325

326-
@test @inferred(indices_to_dims(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))), Array{Int,10})) ==
327-
IndicesInfo{(1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0), nothing}()
326+
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) ==
327+
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}()
328328
end

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
88
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
9-
map_tuple_type, flatten_tuples, GetIndex, indices_to_dims
9+
map_tuple_type, flatten_tuples, GetIndex
1010

1111
# ArrayIndex subtypes and methods
1212
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex

src/dimensions.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11

22

3-
_init_dimsmap(::Type{I}, ::Type{P}) where {I,P} = _init_dimsmap(indices_to_dims(I, P))
43
_init_dimsmap(x) = _init_dimsmap(IndicesInfo(x))
5-
function _init_dimsmap(::IndicesInfo{pdims,cdims,nothing}) where {pdims,cdims}
4+
function _init_dimsmap(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
65
ntuple(i -> static(getfield(pdims, i)), length(pdims)),
76
ntuple(i -> static(getfield(cdims, i)), length(pdims))
87
end
@@ -21,9 +20,9 @@ Returns the mapping from child dimensions to parent dimensions.
2120
"""
2221
to_parent_dims(@nospecialize x) = to_parent_dims(typeof(x))
2322
@inline function to_parent_dims(@nospecialize T::Type{<:SubArray})
24-
to_parent_dims(indices_to_dims(fieldtype(T, :indices), parent_type(T)))
23+
to_parent_dims(IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices)))
2524
end
26-
function to_parent_dims(::IndicesInfo{pdims,cdims,nothing}) where {pdims,cdims}
25+
function to_parent_dims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
2726
flatten_tuples(ntuple(length(cdims)) do i
2827
cdim_i = getfield(cdims, i)
2928
if cdim_i isa Tuple
@@ -48,7 +47,7 @@ end
4847
# Base will sometomes demote statically known slices in `SubArray` to `OneTo{Int}` so we
4948
# provide the parent mapping to check for static size info
5049
sub_axes_map(@nospecialize(T::Type{<:SubArray})) = _sub_axes_map(T, IndicesInfo(T))
51-
function _sub_axes_map(@nospecialize(T::Type{<:SubArray}), ::IndicesInfo{pdims}) where {pdims}
50+
function _sub_axes_map(@nospecialize(T::Type{<:SubArray}), ::IndicesInfo{N,pdims}) where {N,pdims}
5251
ntuple(length(pdims)) do i
5352
if fieldtype(fieldtype(T, :indices), i) <: Base.Slice{OneTo{Int}}
5453
sz = known_size(parent_type(T), getfield(pdims, i))
@@ -60,7 +59,7 @@ function _sub_axes_map(@nospecialize(T::Type{<:SubArray}), ::IndicesInfo{pdims})
6059
end
6160

6261
sub_dimnames_map(@nospecialize T::Type{<:SubArray}) = _sub_dimnames_map(IndicesInfo(T))
63-
function _sub_dimnames_map(::IndicesInfo{pdims,cdims}) where {pdims,cdims}
62+
function _sub_dimnames_map(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
6463
ntuple(length(pdims)) do i
6564
cdim_i = getfield(cdims, i)
6665
pdim_i = getfield(pdims, i)
@@ -87,10 +86,10 @@ from_parent_dims(@nospecialize x) = from_parent_dims(typeof(x))
8786
from_parent_dims(@nospecialize T::Type{<:PermutedDimsArray}) = getfield(_permdims(T), 2)
8887
from_parent_dims(@nospecialize T::Type{<:MatAdjTrans}) = (StaticInt(2), StaticInt(1))
8988
@inline function from_parent_dims(@nospecialize T::Type{<:SubArray})
90-
from_parent_dims(indices_to_dims(fieldtype(T, :indices), parent_type(T)))
89+
from_parent_dims(IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices)))
9190
end
9291
# TODO do I need to flatten_tuples here?
93-
function from_parent_dims(::IndicesInfo{pdims,cdims,nothing}) where {pdims,cdims}
92+
function from_parent_dims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
9493
ntuple(length(cdims)) do i
9594
pdim_i = getfield(pdims, i)
9695
cdim_i = static(getfield(cdims, i))
@@ -130,7 +129,6 @@ function _known_sub_dimname(::Tuple, ::Pair{StaticSymbol{:underscore},StaticInt{
130129
ntuple(Compat.Returns(:_), StaticInt(N))
131130
end
132131

133-
134132
function known_dimnames(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
135133
pnames = known_dimnames(A)
136134
if IsReshaped

src/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ This implementation differs from that of `Base.to_indices` in the following ways
8181
"""
8282
to_indices(A, ::Tuple{}) = ()
8383
@inline function to_indices(a::A, inds::I) where {A,I}
84-
flatten_tuples(map(IndexedMappedArray(a), inds, getfield(_init_dimsmap(I, A), 1)))
84+
flatten_tuples(map(IndexedMappedArray(a), inds, getfield(_init_dimsmap(IndicesInfo{ndims(A)}(I)), 1)))
8585
end
8686

8787
struct IndexedMappedArray{A}

src/stridelayout.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ contiguous_axis(::Type{<:StrideIndex{N,R,nothing}}) where {N,R} = nothing
125125
function contiguous_axis(::Type{T}) where {T}
126126
is_forwarding_wrapper(T) ? contiguous_axis(parent_type(T)) : nothing
127127
end
128-
contiguous_axis(::Type{<:DenseArray}) = One()
128+
contiguous_axis(@nospecialize T::Type{<:DenseArray}) = One()
129129
contiguous_axis(::Type{<:BitArray}) = One()
130-
contiguous_axis(::Type{<:AbstractRange}) = One()
131-
contiguous_axis(::Type{<:Tuple}) = One()
132-
function contiguous_axis(::Type{T}) where {T<:VecAdjTrans}
130+
contiguous_axis(@nospecialize T::Type{<:AbstractRange}) = One()
131+
contiguous_axis(@nospecialize T::Type{<:Tuple}) = One()
132+
function contiguous_axis(@nospecialize T::Type{<:VecAdjTrans})
133133
c = contiguous_axis(parent_type(T))
134134
if c === nothing
135135
return nothing
@@ -139,7 +139,7 @@ function contiguous_axis(::Type{T}) where {T<:VecAdjTrans}
139139
return -One()
140140
end
141141
end
142-
function contiguous_axis(::Type{T}) where {T<:MatAdjTrans}
142+
function contiguous_axis(@nospecialize T::Type{<:MatAdjTrans})
143143
c = contiguous_axis(parent_type(T))
144144
if c === nothing
145145
return nothing
@@ -149,7 +149,7 @@ function contiguous_axis(::Type{T}) where {T<:MatAdjTrans}
149149
return StaticInt(3) - c
150150
end
151151
end
152-
function contiguous_axis(::Type{T}) where {T<:PermutedDimsArray}
152+
function contiguous_axis(@nospecialize T::Type{<:PermutedDimsArray})
153153
c = contiguous_axis(parent_type(T))
154154
if c === nothing
155155
return nothing

0 commit comments

Comments
 (0)