Skip to content

Commit dadccd9

Browse files
committed
consolidate and clean-up code
1 parent f62c8dc commit dadccd9

File tree

7 files changed

+183
-173
lines changed

7 files changed

+183
-173
lines changed

lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,15 +678,76 @@ Provides basic trait information for each index type in in the tuple `T`. `NI`,
678678
[`is_splat_index`](@ref) (respectively) for each field of `T`.
679679
"""
680680
struct IndicesInfo{NI,NS,IS} end
681-
IndicesInfo(x::Tuple) = IndicesInfo(typeof(x))
681+
IndicesInfo(x::Union{Tuple,SubArray}) = IndicesInfo(typeof(x))
682682
function IndicesInfo(@nospecialize T::Type{<:Tuple})
683683
IndicesInfo{
684684
map_tuple_type(ndims_index, T),
685685
map_tuple_type(ndims_shape, T),
686686
_find_first_true(map_tuple_type(is_splat_index, T))
687-
#findfirst(==(1), map_tuple_type(is_splat_index, T))
688687
}()
689688
end
689+
function IndicesInfo(@nospecialize T::Type{<:SubArray})
690+
indices_to_dims(IndicesInfo(fieldtype(T, :indices)), parent_type(T))
691+
end
692+
693+
@inline function indices_to_dims(@nospecialize(I::Type{<:Tuple}), @nospecialize(T::Type))
694+
indices_to_dims(IndicesInfo(I), T)
695+
end
696+
@inline function indices_to_dims(::IndicesInfo{(1,),NS,nothing}, @nospecialize(T::Type)) where {NS}
697+
ns1 = getfield(NS, 1)
698+
nsd = ns1 > 1 ? ntuple(identity, ns1) : ns1
699+
if ndims(T) === 1
700+
IndicesInfo{(1,), (nsd,), nothing}()
701+
else
702+
IndicesInfo{(:,), (nsd,), nothing}()
703+
end
704+
end
705+
@inline function indices_to_dims(::IndicesInfo{NI,NS,nothing}, @nospecialize(T::Type)) where {NI,NS}
706+
if sum(NI) > ndims(T)
707+
IndicesInfo{_replace_trailing(ndims(T), _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS), nothing}()
708+
else
709+
IndicesInfo{_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS), nothing}()
710+
end
711+
end
712+
@inline function indices_to_dims(::IndicesInfo{NI,NS,SI}, @nospecialize(T::Type)) where {NI,NS,SI}
713+
nsplat = ndims(T) - sum(NI)
714+
if nsplat === 0
715+
indices_to_dims(IndicesInfo{NI,NS,nothing}(), T)
716+
else
717+
splatmul = max(0, nsplat + 1)
718+
indices_to_dims(IndicesInfo{_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS),nothing}(), T)
719+
end
720+
end
721+
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})
722+
ntuple(length(dims)) do i
723+
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
724+
end
725+
end
726+
@inline function _replace_trailing(n::Int, dims::Tuple{Vararg{Any,N}}) where {N}
727+
ntuple(N) do i
728+
dim_i = getfield(dims, i)
729+
if dim_i isa Tuple
730+
ntuple(length(dim_i)) do j
731+
dim_i_j = getfield(dim_i, j)
732+
dim_i_j > n ? 0 : dim_i_j
733+
end
734+
else
735+
dim_i > n ? 0 : dim_i
736+
end
737+
end
738+
end
739+
@inline function _accum_dims(csdims::NTuple{N,Int}, nd::NTuple{N,Int}) where {N}
740+
ntuple(N) do i
741+
nd_i = getfield(nd, i)
742+
if nd_i === 0
743+
0
744+
elseif nd_i === 1
745+
getfield(csdims, i)
746+
else
747+
ntuple(Base.Fix1(+, getfield(csdims, i) - nd_i), nd_i)
748+
end
749+
end
750+
end
690751

691752
"""
692753
instances_do_not_alias(::Type{T}) -> Bool

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ArrayInterfaceCore
22
using ArrayInterfaceCore: zeromatrix
33
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
4-
parent_type, zeromatrix
4+
parent_type, zeromatrix, IndicesInfo, indices_to_dims
55
using Base: setindex
66
using LinearAlgebra
77
using Random
@@ -293,3 +293,36 @@ end
293293
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(fill(rand(4,4),4,4)', 2:3, 1:2)))
294294
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(rand(4,4)', StepRangeLen(1,0,5), 1:2)))
295295
end
296+
297+
@testset "indices_to_dims" begin
298+
299+
struct SplatFirst end
300+
301+
ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true
302+
303+
@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
304+
IndicesInfo{(1,),((1,2),),nothing}()
305+
306+
@test @inferred(indices_to_dims(Tuple{Vector{Int}}, Vector{Int})) == IndicesInfo{(1,), (1,), nothing}()
307+
308+
@test @inferred(indices_to_dims(Tuple{Vector{Int}}, Matrix{Int})) == IndicesInfo{(:,), (1,), nothing}()
309+
310+
@test @inferred(indices_to_dims(Tuple{SplatFirst}, Vector{Int})) == IndicesInfo{(1,), (1,), nothing}()
311+
312+
@test @inferred(indices_to_dims(Tuple{SplatFirst}, Matrix{Int})) == IndicesInfo{((1,2),), ((1, 2),), nothing}()
313+
314+
@test @inferred(indices_to_dims(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)), Array{Int,5})) ==
315+
IndicesInfo{(1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0), nothing}()
316+
317+
@test @inferred(indices_to_dims(Tuple{Vararg{Int,10}}, Array{Int,10})) ==
318+
IndicesInfo{(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0), nothing}()
319+
320+
@test @inferred(indices_to_dims(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)),Array{Int,10})) ==
321+
IndicesInfo{(1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0), nothing}()
322+
323+
@test @inferred(indices_to_dims(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)), Array{Int,10})) ==
324+
IndicesInfo{((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0), nothing}()
325+
326+
@test @inferred(indices_to_dims(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))), Array{Int,10})) ==
327+
IndicesInfo{(1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0), nothing}()
328+
end

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
88
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
9-
map_tuple_type, flatten_tuples, GetIndex
9+
map_tuple_type, flatten_tuples, GetIndex, indices_to_dims
1010

1111
# ArrayIndex subtypes and methods
1212
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex

src/axes.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,10 @@ axes(A::ReshapedArray) = Base.axes(A)
114114
map(GetIndex{false}(axes(parent(x))), to_parent_dims(x))
115115
end
116116
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
117-
@inline function axes(x::SubArray)
118-
flatten_tuples(map((i, d) -> _sub_axis(parent(x), i, d), x.indices, dimsmap(x)))
119-
end
120-
_sub_axis(x, idx, ::Tuple{StaticInt{0}}) = ()
121-
function _sub_axis(x, idx::Base.Slice{Base.OneTo{Int}}, dm::Tuple{StaticInt,Any})
122-
sz = known_size(x, getfield(dm, 2))
123-
sz === nothing ? axes(idx) : StaticInt(1):StaticInt(sz)
124-
end
125-
_sub_axis(x, idx, ::Tuple{Any,Any}) = axes(idx)
126117

118+
@inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x))))
119+
@inline _sub_axes(x::SubArray, ::Pair{StaticSymbol{:parent},StaticInt{s}}) where {s} = StaticInt(1):StaticInt(s)
120+
_sub_axes(x::SubArray, ::StaticInt{index}) where {index} = axes(getfield(x.indices, index))
127121

128122
@inline axes(A, dim) = _axes(A, to_dims(A, dim))
129123
@inline _axes(A, dim::Int) = dim > ndims(A) ? OneTo(1) : getfield(axes(A), dim)

0 commit comments

Comments
 (0)