Skip to content

Commit 2c62ec0

Browse files
committed
More docs and protect constructors
Before making this public I wanted to ensure that we had control over the construction of `IndicesInfo` so it was guaranteed to be valid. Now it has inner struct constructors only. Also provided `ndims_index` and `ndims_shape` so all information can be accessed without directly looking at `IndicesInfo`'s parametric typing.
1 parent f6e8f5a commit 2c62ec0

File tree

2 files changed

+93
-72
lines changed

2 files changed

+93
-72
lines changed

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,9 @@ end
797797
"""
798798
IndicesInfo{N}(inds::Tuple) -> IndicesInfo{N}(typeof(inds))
799799
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,pdims,cdims}()
800+
IndicesInfo(inds::Tuple) -> IndicesInfo(typeof(inds))
801+
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{maximum(pdims),pdims,cdims}()
802+
800803
801804
Maps a tuple of indices to `N` dimensions. The resulting `pdims` is a tuple where each
802805
field in `inds` (or field type in `T`) corresponds to the parent dimensions accessed.
@@ -807,12 +810,14 @@ dimensions and there are no trailing dimensions accessed. These may be accessed
807810
it is assumed that no indices are accessing trailing dimensions (which are represented as
808811
`0` in `parentdims(info)[index_position]`).
809812
810-
See also: [`parentdims`](@ref), [`childdims`](@ref)
813+
The the fields and types of `IndicesInfo` should not be accessed directly.
814+
Instead [`parentdims`](@ref), [`childdims`](@ref), [`ndims_index`](@ref), and
815+
[`ndims_shape`](@ref) should be used to extract relevant information.
811816
812817
# Examples
813818
814819
```julia
815-
julia> using ArrayInterfaceCore: IndicesInfo, parentdims, childdims
820+
julia> using ArrayInterfaceCore: IndicesInfo, parentdims, childdims, ndims_index, ndims_shape
816821
817822
julia> info = IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
818823
@@ -822,80 +827,90 @@ julia> parentdims(info) # the last two indices access trailing dimensions
822827
julia> childdims(info)
823828
(1, 2, 0, (3, 4), 5, 0)
824829
830+
julia> childdims(info)[3] # index 3 accesses a parent dimension but is dropped in the child array
831+
0
832+
833+
julia> ndims_index(info)
834+
5
835+
836+
julia> ndims_shape(info)
837+
5
838+
825839
julia> info = IndicesInfo(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)));
826840
827-
julia> parentdims(info)
841+
julia> parentdims(info) # assumed no trailing dimensions
828842
(1, (2, 3), 4, 5, 6, 7)
829843
830-
julia> childdims(info)[3] # index 3 accesses a parent dimension but is dropped in the child array
831-
0
844+
julia> ndims_index(info) # assumed no trailing dimensions
845+
7
832846
833847
```
834848
"""
835-
struct IndicesInfo{N,pdims,cdims} end
836-
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
837-
SI = _find_first_true(map_tuple_type(is_splat_index, T))
838-
NI = map_tuple_type(ndims_index, T)
839-
NS = map_tuple_type(ndims_shape, T)
840-
if SI === nothing
841-
ndi = NI
842-
nds = NS
843-
else
844-
nsplat = N - sum(NI)
845-
if nsplat === 0
849+
struct IndicesInfo{Np,pdims,cdims,Nc}
850+
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
851+
SI = _find_first_true(map_tuple_type(is_splat_index, T))
852+
NI = map_tuple_type(ndims_index, T)
853+
NS = map_tuple_type(ndims_shape, T)
854+
if SI === nothing
846855
ndi = NI
847856
nds = NS
848857
else
849-
splatmul = max(0, nsplat + 1)
850-
ndi = _map_splats(splatmul, SI, NI)
851-
nds = _map_splats(splatmul, SI, NS)
858+
nsplat = N - sum(NI)
859+
if nsplat === 0
860+
ndi = NI
861+
nds = NS
862+
else
863+
splatmul = max(0, nsplat + 1)
864+
ndi = _map_splats(splatmul, SI, NI)
865+
nds = _map_splats(splatmul, SI, NS)
866+
end
852867
end
853-
end
854-
if ndi === (1,) && N !== 1
855-
ns1 = getfield(nds, 1)
856-
IndicesInfo{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
857-
else
858-
if sum(ndi) > N
859-
init_pdims = _accum_dims(ndi)
860-
pdims = ntuple(nfields(init_pdims)) do i
861-
dim_i = getfield(init_pdims, i)
862-
if dim_i isa Tuple
863-
ntuple(length(dim_i)) do j
864-
dim_i_j = getfield(dim_i, j)
865-
dim_i_j > N ? 0 : dim_i_j
868+
if ndi === (1,) && N !== 1
869+
ns1 = getfield(nds, 1)
870+
new{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,),ns1}()
871+
else
872+
nds_cumsum = cumsum(nds)
873+
if sum(ndi) > N
874+
init_pdims = _accum_dims(cumsum(ndi), ndi)
875+
pdims = ntuple(nfields(init_pdims)) do i
876+
dim_i = getfield(init_pdims, i)
877+
if dim_i isa Tuple
878+
ntuple(length(dim_i)) do j
879+
dim_i_j = getfield(dim_i, j)
880+
dim_i_j > N ? 0 : dim_i_j
881+
end
882+
else
883+
dim_i > N ? 0 : dim_i
866884
end
867-
else
868-
dim_i > N ? 0 : dim_i
869885
end
886+
new{N, pdims, _accum_dims(nds_cumsum, nds), last(nds_cumsum)}()
887+
else
888+
new{N,_accum_dims(cumsum(ndi), ndi), _accum_dims(nds_cumsum, nds), last(nds_cumsum)}()
870889
end
871-
IndicesInfo{N, pdims, _accum_dims(nds)}()
872-
else
873-
IndicesInfo{N,_accum_dims(ndi), _accum_dims(nds)}()
874890
end
875891
end
892+
IndicesInfo{N}(@nospecialize(t::Tuple)) where {N} = IndicesInfo{N}(typeof(t))
893+
function IndicesInfo(@nospecialize(T::Type{<:Tuple}))
894+
ndi = map_tuple_type(ndims_index, T)
895+
nds = map_tuple_type(ndims_shape, T)
896+
ndi_sum = cumsum(ndi)
897+
nds_sum = cumsum(nds)
898+
nf = nfields(ndi_sum)
899+
pdims = _accum_dims(ndi_sum, ndi)
900+
cdims = _accum_dims(nds_sum, nds)
901+
new{getfield(ndi_sum, nf),pdims,cdims,getfield(nds_sum, nf)}()
902+
end
903+
IndicesInfo(@nospecialize t::Tuple) = IndicesInfo(typeof(t))
904+
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
905+
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
906+
end
907+
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
876908
end
877-
IndicesInfo{N}(@nospecialize(t::Tuple)) where {N} = IndicesInfo{N}(typeof(t))
878-
879-
function IndicesInfo(@nospecialize(T::Type{<:Tuple}))
880-
ndi = map_tuple_type(ndims_index, T)
881-
nds = map_tuple_type(ndims_shape, T)
882-
ndi_sum = cumsum(ndi)
883-
nds_sum = cumsum(nds)
884-
pdims = _accum_dims(ndi_sum, ndi)
885-
cdims = _accum_dims(nds_sum, nds)
886-
IndicesInfo{last(ndi_sum),pdims,cdims}()
887-
end
888-
IndicesInfo(@nospecialize t::Tuple) = IndicesInfo(typeof(t))
889-
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
890-
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
891-
end
892-
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
893909
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})
894910
ntuple(length(dims)) do i
895911
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
896912
end
897913
end
898-
@inline _accum_dims(nd::Tuple{Vararg{Int}}) = _accum_dims(cumsum(nd), nd)
899914
@inline function _accum_dims(csdims::NTuple{N,Int}, nd::NTuple{N,Int}) where {N}
900915
ntuple(N) do i
901916
nd_i = getfield(nd, i)
@@ -909,14 +924,19 @@ end
909924
end
910925
end
911926

927+
_lower_info(::IndicesInfo{Np,pdims,cdims,Nc}) where {Np,pdims,cdims,Nc} = Np,pdims,cdims,Nc
928+
929+
ndims_index(@nospecialize(info::IndicesInfo)) = getfield(_lower_info(info), 1)
930+
ndims_shape(@nospecialize(info::IndicesInfo)) = getfield(_lower_info(info), 4)
931+
912932
"""
913933
parentdims(::IndicesInfo) -> Tuple
914934
915935
Returns the parent dimension mapping from `IndicesInfo`.
916936
917937
See also: [`IndicesInfo`](@ref), [`childdims`](@ref)
918938
"""
919-
parentdims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} = pdims
939+
parentdims(@nospecialize info::IndicesInfo) = getfield(_lower_info(info), 2)
920940

921941
"""
922942
childdims(::IndicesInfo) -> Tuple
@@ -925,7 +945,8 @@ Returns the child dimension mapping from `IndicesInfo`.
925945
926946
See also: [`IndicesInfo`](@ref), [`parentdims`](@ref)
927947
"""
928-
childdims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims} = cdims
948+
childdims(@nospecialize info::IndicesInfo) = getfield(_lower_info(info), 3)
949+
929950

930951
"""
931952
instances_do_not_alias(::Type{T}) -> Bool

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -300,29 +300,29 @@ end
300300

301301
ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true
302302

303-
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
304-
IndicesInfo{1,(1,),((1,2),)}()
303+
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) isa
304+
IndicesInfo{1,(1,),((1,2),)}
305305

306-
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) == IndicesInfo{1, (1,), (1,)}()
306+
@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) isa IndicesInfo{1, (1,), (1,)}
307307

308-
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) == IndicesInfo{2, (:,), (1,)}()
308+
@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) isa IndicesInfo{2, (:,), (1,)}
309309

310-
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) == IndicesInfo{1, (1,), (1,)}()
310+
@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) isa IndicesInfo{1, (1,), (1,)}
311311

312-
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) == IndicesInfo{2, ((1,2),), ((1, 2),)}()
312+
@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) isa IndicesInfo{2, ((1,2),), ((1, 2),)}
313313

314-
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) ==
315-
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()
314+
@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) isa
315+
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}
316316

317-
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) ==
318-
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}()
317+
@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) isa
318+
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}
319319

320-
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) ==
321-
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}()
320+
@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) isa
321+
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}
322322

323-
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) ==
324-
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}()
323+
@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) isa
324+
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}
325325

326-
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) ==
327-
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}()
326+
@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) isa
327+
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}
328328
end

0 commit comments

Comments
 (0)