Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/ArrayInterfaceCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterfaceCore"
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
version = "0.1.13"
version = "0.1.14"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
168 changes: 151 additions & 17 deletions lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using SuiteSparse
using Base: @assume_effects
else
macro assume_effects(_, ex)
Base.@pure ex
:(Base.@pure $(ex))
end
end

Expand All @@ -22,6 +22,72 @@ const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}

"""
ArrayInterfaceCore.map_tuple_type(f, T::Type{<:Tuple})

Returns tuple where each field corresponds to the field type of `T` modified by the function `f`.

# Examples

```julia
julia> ArrayInterfaceCore.map_tuple_type(sqrt, Tuple{1,4,16})
(1.0, 2.0, 4.0)

```
"""
function map_tuple_type(f::F, ::Type{T}) where {F,T<:Tuple}
if @generated
t = Expr(:tuple)
for i in 1:fieldcount(T)
push!(t.args, :(f($(fieldtype(T, i)))))
end
Expr(:block, Expr(:meta, :inline), t)
else
Tuple(f(fieldtype(T, i)) for i in 1:fieldcount(T))
end
end

"""
ArrayInterfaceCore.flatten_tuples(t::Tuple) -> Tuple

Flattens any field of `t` that is a tuple. Only direct fields of `t` may be flattened.

# Examples

```julia
julia> ArrayInterfaceCore.flatten_tuples((1, ()))
(1,)

julia> ArrayInterfaceCore.flatten_tuples((1, (2, 3)))
(1, 2, 3)

julia> ArrayInterfaceCore.flatten_tuples((1, (2, (3,))))
(1, 2, (3,))

```
"""
@inline function flatten_tuples(t::Tuple)
if @generated
texpr = Expr(:tuple)
for i in 1:fieldcount(t)
p = fieldtype(t, i)
if p <: Tuple
for j in 1:fieldcount(p)
push!(texpr.args, :(@inbounds(getfield(getfield(t, $i), $j))))
end
else
push!(texpr.args, :(@inbounds(getfield(t, $i))))
end
end
Expr(:block, Expr(:meta, :inline), texpr)
else
_flatten(t)
end
end
_flatten(::Tuple{}) = ()
@inline _flatten(t::Tuple{Any,Vararg{Any}}) = (getfield(t, 1), _flatten(Base.tail(t))...)
@inline _flatten(t::Tuple{Tuple,Vararg{Any}}) = (getfield(t, 1)..., _flatten(Base.tail(t))...)

"""
parent_type(::Type{T}) -> Type

Expand Down Expand Up @@ -591,32 +657,100 @@ indexing with an instance of `I`.
"""
ndims_shape(T::DataType) = ndims_index(T)
ndims_shape(::Type{Colon}) = 1
ndims_shape(T::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = ntuple(zero, Val{N}())
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ntuple(one, Val{ndims(T)}())
ndims_shape(@nospecialize T::Type{<:Number}) = 0
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
ndims_shape(@nospecialize T::Type{<:Union{Number,Base.AbstractCartesianIndex}}) = 0
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
ndims_shape(x) = ndims_shape(typeof(x))

@assume_effects :total function _find_first_true(isi::Tuple{Vararg{Bool,N}}) where {N}
for i in 1:N
getfield(isi, i) && return i
end
return nothing
end

"""
IndicesInfo(T::Type{<:Tuple}) -> IndicesInfo{NI,NS,IS}()
IndicesInfo{N}(T::Type{<:Tuple}) -> IndicesInfo{N,NI,NS}()

Provides basic trait information for each index type in in the tuple `T`. `NI`, `NS`, and
`IS` are tuples of [`ndims_index`](@ref), [`ndims_shape`](@ref), and
[`is_splat_index`](@ref) (respectively) for each field of `T`.

# Examples

```julia
julia> using ArrayInterfaceCore: IndicesInfo

julia> IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()

```
"""
struct IndicesInfo{NI,NS,IS} end
IndicesInfo(@nospecialize x::Tuple) = IndicesInfo(typeof(x))
@generated function IndicesInfo(::Type{T}) where {T<:Tuple}
NI = Expr(:tuple)
NS = Expr(:tuple)
IS = Expr(:tuple)
for i in 1:fieldcount(T)
T_i = fieldtype(T, i)
push!(NI.args, :(ndims_index($(T_i))))
push!(NS.args, :(ndims_shape($(T_i))))
push!(IS.args, :(is_splat_index($(T_i))))
struct IndicesInfo{N,NI,NS} end
IndicesInfo(x::SubArray) = IndicesInfo{ndims(parent(x))}(typeof(x.indices))
@inline function IndicesInfo(@nospecialize T::Type{<:SubArray})
IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices))
end
function IndicesInfo{N}(@nospecialize(T::Type{<:Tuple})) where {N}
_indices_info(
Val{_find_first_true(map_tuple_type(is_splat_index, T))}(),
IndicesInfo{N,map_tuple_type(ndims_index, T),map_tuple_type(ndims_shape, T)}()
)
end
function _indices_info(::Val{nothing}, ::IndicesInfo{1,(1,),NS}) where {NS}
ns1 = getfield(NS, 1)
IndicesInfo{1,(1,), (ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
end
function _indices_info(::Val{nothing}, ::IndicesInfo{N,(1,),NS}) where {N,NS}
ns1 = getfield(NS, 1)
IndicesInfo{N,(:,),(ns1 > 1 ? ntuple(identity, ns1) : ns1,)}()
end
@inline function _indices_info(::Val{nothing}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS}
if sum(NI) > N
IndicesInfo{N,_replace_trailing(N, _accum_dims(cumsum(NI), NI)), _accum_dims(cumsum(NS), NS)}()
else
IndicesInfo{N,_accum_dims(cumsum(NI), NI), _accum_dims(cumsum(NS), NS)}()
end
end
@inline function _indices_info(::Val{SI}, ::IndicesInfo{N,NI,NS}) where {N,NI,NS,SI}
nsplat = N - sum(NI)
if nsplat === 0
_indices_info(Val{nothing}(), IndicesInfo{N,NI,NS}())
else
splatmul = max(0, nsplat + 1)
_indices_info(Val{nothing}(), IndicesInfo{N,_map_splats(splatmul, SI, NI),_map_splats(splatmul, SI, NS)}())
end
end
@inline function _map_splats(nsplat::Int, splat_index::Int, dims::Tuple{Vararg{Int}})
ntuple(length(dims)) do i
i === splat_index ? (nsplat * getfield(dims, i)) : getfield(dims, i)
end
end
@inline function _replace_trailing(n::Int, dims::Tuple{Vararg{Any,N}}) where {N}
ntuple(N) do i
dim_i = getfield(dims, i)
if dim_i isa Tuple
ntuple(length(dim_i)) do j
dim_i_j = getfield(dim_i, j)
dim_i_j > n ? 0 : dim_i_j
end
else
dim_i > n ? 0 : dim_i
end
end
end
@inline function _accum_dims(csdims::NTuple{N,Int}, nd::NTuple{N,Int}) where {N}
ntuple(N) do i
nd_i = getfield(nd, i)
if nd_i === 0
0
elseif nd_i === 1
getfield(csdims, i)
else
ntuple(Base.Fix1(+, getfield(csdims, i) - nd_i), nd_i)
end
end
Expr(:block, Expr(:meta, :inline), :(IndicesInfo{$(NI),$(NS),$(IS)}()))
end

"""
Expand Down
39 changes: 36 additions & 3 deletions lib/ArrayInterfaceCore/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ArrayInterfaceCore
using ArrayInterfaceCore: zeromatrix
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
parent_type, zeromatrix
parent_type, zeromatrix, IndicesInfo
using Base: setindex
using LinearAlgebra
using Random
Expand Down Expand Up @@ -271,8 +271,8 @@ end
@testset "ndims_shape" begin
@test @inferred(ArrayInterfaceCore.ndims_shape(1)) === 0
@test @inferred(ArrayInterfaceCore.ndims_shape(:)) === 1
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndex(1, 2))) === (0, 0)
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndices((2,2)))) === (1, 1)
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndex(1, 2))) === 0
@test @inferred(ArrayInterfaceCore.ndims_shape(CartesianIndices((2,2)))) === 2
@test @inferred(ArrayInterfaceCore.ndims_shape([1 1])) === 2
end

Expand All @@ -293,3 +293,36 @@ end
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(fill(rand(4,4),4,4)', 2:3, 1:2)))
@test !ArrayInterfaceCore.indices_do_not_alias(typeof(view(rand(4,4)', StepRangeLen(1,0,5), 1:2)))
end

@testset "IndicesInfo" begin

struct SplatFirst end

ArrayInterfaceCore.is_splat_index(::Type{SplatFirst}) = true

@test @inferred(IndicesInfo(SubArray{Float64, 2, Vector{Float64}, Tuple{Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}, true})) ==
IndicesInfo{1,(1,),((1,2),)}()

@test @inferred(IndicesInfo{1}((Tuple{Vector{Int}}))) == IndicesInfo{1, (1,), (1,)}()

@test @inferred(IndicesInfo{2}(Tuple{Vector{Int}})) == IndicesInfo{2, (:,), (1,)}()

@test @inferred(IndicesInfo{1}(Tuple{SplatFirst})) == IndicesInfo{1, (1,), (1,)}()

@test @inferred(IndicesInfo{2}(Tuple{SplatFirst})) == IndicesInfo{2, ((1,2),), ((1, 2),)}()

@test @inferred(IndicesInfo{5}(typeof((:,[CartesianIndex(1,1),CartesianIndex(1,1)], 1, ones(Int, 2, 2), :, 1)))) ==
IndicesInfo{5, (1, (2, 3), 4, 5, 0, 0), (1, 2, 0, (3, 4), 5, 0)}()

@test @inferred(IndicesInfo{10}(Tuple{Vararg{Int,10}})) ==
IndicesInfo{10, (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)}()

@test @inferred(IndicesInfo{10}(typeof((1, CartesianIndex(2, 1), 2, CartesianIndex(1, 2), 1, CartesianIndex(2, 1), 2)))) ==
IndicesInfo{10, (1, (2, 3), 4, (5, 6), 7, (8, 9), 10), (0, 0, 0, 0, 0, 0, 0)}()

@test @inferred(IndicesInfo{10}(typeof((fill(true, 4, 4), 2, fill(true, 4, 4), 2, 1, fill(true, 4, 4), 1)))) ==
IndicesInfo{10, ((1, 2), 3, (4, 5), 6, 7, (8, 9), 10), (1, 0, 2, 0, 0, 3, 0)}()

@test @inferred(IndicesInfo{10}(typeof((1, SplatFirst(), 2, SplatFirst(), CartesianIndex(1, 1))))) ==
IndicesInfo{10, (1, (2, 3, 4, 5, 6), 7, 8, (9, 10)), (0, (1, 2, 3, 4, 5), 0, 6, 0)}()
end
5 changes: 2 additions & 3 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
issingular, isstructured, matrix_colors, restructure, lu_instance,
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo,
map_tuple_type, flatten_tuples, GetIndex

# ArrayIndex subtypes and methods
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex
Expand Down Expand Up @@ -34,8 +35,6 @@ using LinearAlgebra

import Compat

n_of_x(::StaticInt{N}, x::X) where {N,X} = ntuple(Compat.Returns(x), Val{N}())

_add1(@nospecialize x) = x + oneunit(x)
_sub1(@nospecialize x) = x - oneunit(x)

Expand Down
Loading