Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
ArrayInterfaceCore = "0.1.3"
Compat = "3, 4"
IfElse = "0.1"
Static = "0.6"
Static = "0.7"
julia = "1.6"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions lib/ArrayInterfaceOffsetArrays/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterfaceOffsetArrays"
uuid = "015c0d05-e682-4f19-8f0a-679ce4c54826"
version = "0.1.5"
version = "0.1.6"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -10,7 +10,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
[compat]
ArrayInterface = "5, 6"
OffsetArrays = "1.11"
Static = "0.6"
Static = "0.7"
julia = "1.6"

[extras]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
end
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterface.parent_type(T))
Static.eachop_tuple(
_offset_axis_type,
ntuple(static, StaticInt(ndims(T))),
ArrayInterface.parent_type(T)
)
end
ArrayInterface.strides(A::OffsetArray) = ArrayInterface.strides(parent(A))
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
Expand Down
4 changes: 2 additions & 2 deletions lib/ArrayInterfaceStaticArrays/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterfaceStaticArrays"
uuid = "b0d46f97-bff5-4637-a19a-dd75974142cd"
version = "0.1.2"
version = "0.1.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -12,7 +12,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[compat]
Adapt = "3"
ArrayInterface = "6"
Static = "0.6"
Static = "0.7"
StaticArrays = "1.2.5, 1.3, 1.4"
julia = "1.6"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer
ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple()
ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
ArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} = Static.nstatic(Val(N))
function ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}}
ntuple(static, StaticInt(N))
end
function ArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
ArrayInterface._all_dense(Val(N))
end
Expand Down
7 changes: 4 additions & 3 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ end
end

function axes_types(::Type{T}) where {T<:ReinterpretArray}
eachop_tuple(_non_reshaped_axis_type, nstatic(Val(ndims(T))), T)
eachop_tuple(_non_reshaped_axis_type, ntuple(static, StaticInt(ndims(T))), T)
end

function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D}
Expand Down Expand Up @@ -146,7 +146,7 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N
return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(parent_type(A)))
elseif sizeof(S) < sizeof(T)
P = parent_type(A)
return eachop_tuple(field_type, tail(nstatic(Val(ndims(P)))), axes_types(P))
return eachop_tuple(field_type, tail(ntuple(static, StaticInt(ndims(P)))), axes_types(P))
else
return axes_types(parent_type(A))
end
Expand Down Expand Up @@ -241,9 +241,10 @@ Base.axes1(x::LazyAxis) = x
Base.axes(x::Slice{<:LazyAxis}) = (Base.axes1(x),)
# assuming that lazy loaded params like dynamic length from `size(::Array, dim)` are going
# be used again later with `Slice{LazyAxis}`, we quickly load indices

Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A} = indices(getfield(x.indices, :parent), StaticInt(N))
Base.axes1(x::Slice{LazyAxis{:,A}}) where {A} = indices(getfield(x.indices, :parent))
Base.to_shape(x::LazyAxis) = length(x)
Base.to_shape(x::LazyAxis) = Base.length(x)

@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
Expand Down
31 changes: 7 additions & 24 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,6 @@ function throw_dim_error(@nospecialize(x), @nospecialize(dim))
throw(DimensionMismatch("$x does not have dimension corresponding to $dim"))
end

@propagate_inbounds function _promote_shape(a::Tuple{A,Vararg{Any}}, b::Tuple{B,Vararg{Any}}) where {A,B}
(_try_static(getfield(a, 1), getfield(b, 1)), _promote_shape(tail(a), tail(b))...)
end
_promote_shape(::Tuple{}, ::Tuple{}) = ()
@propagate_inbounds function _promote_shape(::Tuple{}, b::Tuple{B}) where {B}
(_try_static(static(1), getfield(b, 1)),)
end
@propagate_inbounds function _promote_shape(a::Tuple{A}, ::Tuple{}) where {A}
(_try_static(static(1), getfield(a, 1)),)
end
@propagate_inbounds function Base.promote_shape(a::Tuple{Vararg{CanonicalInt}}, b::Tuple{Vararg{CanonicalInt}})
_promote_shape(a, b)
end

#julia> @btime ArrayInterfaceCore.is_increasing(ArrayInterfaceCore.nstatic(Val(10)))
# 0.045 ns (0 allocations: 0 bytes)
#ArrayInterfaceCore.True()
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y}
if X <= Y
return is_increasing(tail(perm))
Expand All @@ -44,7 +27,7 @@ is_increasing(::Tuple{}) = True()
Returns the mapping from parent dimensions to child dimensions.
"""
from_parent_dims(x) = from_parent_dims(typeof(x))
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
from_parent_dims(::Type{T}) where {T} = ntuple(static, StaticInt(ndims(T)))
from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),)
from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One())
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I)
Expand All @@ -65,11 +48,11 @@ end
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))
function from_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
if !IsReshaped || sizeof(S) === sizeof(T)
return nstatic(Val(ndims(A)))
return ntuple(static, StaticInt(ndims(A)))
elseif sizeof(S) > sizeof(T)
return tail(nstatic(Val(ndims(A) + 1)))
return tail(ntuple(static, StaticInt(ndims(A) + 1)))
else # sizeof(S) < sizeof(T)
return (Zero(), nstatic(Val(N))...)
return (Zero(), ntuple(static, StaticInt(N))...)
end
end

Expand Down Expand Up @@ -100,7 +83,7 @@ end
Returns the mapping from child dimensions to parent dimensions.
"""
to_parent_dims(x) = to_parent_dims(typeof(x))
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
to_parent_dims(::Type{T}) where {T} = ntuple(static, StaticInt(ndims(T)))
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = static(Val(I))
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
Expand All @@ -117,7 +100,7 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
out
end
function to_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
pdims = nstatic(Val(ndims(A)))
pdims = ntuple(static, StaticInt(ndims(A)))
if !IsReshaped || sizeof(S) === sizeof(T)
return pdims
elseif sizeof(S) > sizeof(T)
Expand Down Expand Up @@ -280,7 +263,7 @@ to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
to_dims(x, dim::Integer) = Int(dim)
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
eachop(_to_dims, nstatic(Val(N)), dimnames(x), dims)
eachop(_to_dims, ntuple(static, StaticInt(N)), dimnames(x), dims)
end
@inline _to_dims(x::Tuple, d::Tuple, n::StaticInt{N}) where {N} = _to_dim(x, getfield(d, N))
@inline function _to_dim(x::Tuple, d::Union{Symbol,StaticSymbol})
Expand Down
4 changes: 2 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ _ints2range_front(::Val{0}, ind, inds...) = ()
_ints2range_front(::Val{0}) = ()
# get output shape with given indices
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
_output_shape(ind::AbstractRange, inds...) = (length(ind), _output_shape(inds...)...)
_output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...)
_output_shape(::CanonicalInt) = ()
_output_shape(x::AbstractRange) = (length(x),)
_output_shape(x::AbstractRange) = (Base.length(x),)
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
return Base._getindex(IndexStyle(A), A, inds...)
Expand Down
6 changes: 3 additions & 3 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ function Base.iterate(::SOneTo{n}, s::Int) where {n}
end
end

Base.to_shape(x::OptionallyStaticRange) = length(x)
Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = length(x)
Base.to_shape(x::OptionallyStaticRange) = Base.length(x)
Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = Base.length(x)
Base.axes(S::Slice{<:OptionallyStaticUnitRange{One}}) = (S.indices,)
Base.axes(S::Slice{<:OptionallyStaticRange}) = (Base.IdentityUnitRange(S.indices),)

Expand Down Expand Up @@ -374,7 +374,7 @@ end
Returns valid indices for each array in `x` along dimension `dim`
"""
@propagate_inbounds function indices(x::Tuple, dim)
inds = map(x_i -> indices(x_i, dim), x)
inds = map(Base.Fix2(indices, dim), x)
return reduce_tup(_pick_range, inds)
end

Expand Down
8 changes: 4 additions & 4 deletions src/size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ size(x::Iterators.Enumerate) = size(getfield(x, :itr))
size(x::Iterators.Accumulate) = size(getfield(x, :itr))
size(x::Iterators.Pairs) = size(getfield(x, :itr))
@inline function size(x::Iterators.ProductIterator)
eachop(_sub_size, nstatic(Val(ndims(x))), getfield(x, :iterators))
eachop(_sub_size, ntuple(static, StaticInt(ndims(x))), getfield(x, :iterators))
end

size(a, dim) = size(a, to_dims(a, dim))
Expand Down Expand Up @@ -100,7 +100,7 @@ known_size(x) = known_size(typeof(x))
end
end
function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
eachop(_known_size, nstatic(Val(N)), axes_types(T))
eachop(_known_size, ntuple(static, StaticInt(N)), axes_types(T))
end
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
function known_size(::Type{T}) where {T<:AbstractRange}
Expand All @@ -113,7 +113,7 @@ known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I)
@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T}
eachop(_known_size, nstatic(Val(known_length(T))), T)
eachop(_known_size, ntuple(static, StaticInt(known_length(T))), T)
end

# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,
Expand All @@ -123,7 +123,7 @@ end
# trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to
# `Nothing` and `Int` types, so we do one last pass to ensure everything is dynamic
@inline function known_size(::Type{<:Iterators.Zip{T}}) where {T}
dynamic(reduce_tup(_promote_shape, eachop(_unzip_size, nstatic(Val(known_length(T))), T)))
dynamic(reduce_tup(Static._promote_shape, eachop(_unzip_size, ntuple(static, StaticInt(known_length(T))), T)))
end
_unzip_size(::Type{T}, n::StaticInt{N}) where {T,N} = known_size(field_type(T, n))
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, dim))
Expand Down
20 changes: 10 additions & 10 deletions src/stridelayout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ stride_preserving_index(::Type{T}) where {T<:AbstractRange} = True()
stride_preserving_index(::Type{T}) where {T<:Int} = True()
stride_preserving_index(::Type{T}) where {T} = False()
function stride_preserving_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
if all(eachop(_stride_preserving_index, nstatic(Val(N)), T))
if all(eachop(_stride_preserving_index, ntuple(static, StaticInt(N)), T))
return True()
else
return False()
Expand Down Expand Up @@ -54,7 +54,7 @@ end

known_offsets(x) = known_offsets(typeof(x))
function known_offsets(::Type{T}) where {T}
return eachop(_known_offsets, nstatic(Val(ndims(T))), axes_types(T))
return eachop(_known_offsets, ntuple(static, StaticInt(ndims(T))), axes_types(T))
end
_known_offsets(::Type{T}, dim::StaticInt) where {T} = known_first(field_type(T, dim))

Expand All @@ -71,7 +71,7 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1)
@inline offsets(x, i) = static_first(indices(x, i))
offsets(::Tuple) = (One(),)
offsets(x::StrideIndex) = getfield(x, :offsets)
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
offsets(x) = eachop(_offsets, ntuple(static, StaticInt(ndims(x))), x)
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
start = known_first(axes_types(X, dim))
if start === nothing
Expand Down Expand Up @@ -216,7 +216,7 @@ end
contiguous_axis_indicator(::A) where {A<:AbstractArray} = contiguous_axis_indicator(A)
contiguous_axis_indicator(::Nothing, ::Val) = nothing
function contiguous_axis_indicator(c::StaticInt{N}, dim::Val{D}) where {N,D}
return map(i -> eq(c, i), nstatic(dim))
map(eq(c), ntuple(static, dim))
end

function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N}
Expand All @@ -233,8 +233,8 @@ stride_rank(x) = stride_rank(typeof(x))
function stride_rank(::Type{T}) where {T}
is_forwarding_wrapper(T) ? stride_rank(parent_type(T)) : nothing
end
stride_rank(::Type{<:DenseArray{T,N}}) where {T,N} = nstatic(Val(N))
stride_rank(::Type{BitArray{N}}) where {N} = nstatic(Val(N))
stride_rank(::Type{<:DenseArray{T,N}}) where {T,N} = ntuple(static, StaticInt(N))
stride_rank(::Type{BitArray{N}}) where {N} = ntuple(static, StaticInt(N))
stride_rank(::Type{<:AbstractRange}) = (One(),)
stride_rank(::Type{<:Tuple}) = (One(),)

Expand All @@ -257,7 +257,7 @@ _stride_rank(::Type{T}, r::Tuple) where {T<:SubArray} = permute(r, to_parent_dim

stride_rank(x, i) = stride_rank(x)[i]
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
return nstatic(Val(N))
return ntuple(static, StaticInt(N))
end
@inline function stride_rank(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
Expand Down Expand Up @@ -304,7 +304,7 @@ function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T,
IfElse.ifelse(is_dense(A), (static(1),), nothing)
end

_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N))
_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = ntuple(static, StaticInt(N))
_reshaped_striderank(_, __, ___) = nothing

"""
Expand Down Expand Up @@ -466,7 +466,7 @@ function dense_dims(T::Type{<:Base.ReshapedArray})
return n_of_x(StaticInt(ndims(T)), False())
end
end

is_dense(A) = is_dense(typeof(A))
is_dense(::Type{A}) where {A} = _is_dense(dense_dims(A))
_is_dense(::Tuple{False,Vararg}) = False()
Expand Down Expand Up @@ -671,7 +671,7 @@ end
@inline function _reinterp_strides(stp::Tuple, els::StaticInt, elp::StaticInt)
if elp % els == 0
N = elp ÷ els
return map(i -> N * i, stp)
return map(Base.Fix2(*, N), stp)
else
return map(stp) do i
d, r = divrem(elp * i, els)
Expand Down
12 changes: 6 additions & 6 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ end
@test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10))) === (:_,)
@test @inferred(ArrayInterface.known_dimnames(Iterators.flatten(1:10), static(1))) === :_
@test @inferred(ArrayInterface.known_dimnames(z)) === (nothing, :y)
@test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) == d
@test @inferred(ArrayInterface.known_dimnames(r1)) == d
@test @inferred(ArrayInterface.known_dimnames(r2)) == (:_, d...)
@test @inferred(ArrayInterface.known_dimnames(r3)) == Base.tail(d)
@test @inferred(ArrayInterface.known_dimnames(r4)) == d
@test @inferred(ArrayInterface.known_dimnames(w)) == d
@test @inferred(ArrayInterface.known_dimnames(reshape(x, (1, 4)))) === (:x, :y)
@test @inferred(ArrayInterface.known_dimnames(r1)) === (:x, :y)
@test @inferred(ArrayInterface.known_dimnames(r2)) === (:_, :x, :y)
@test @inferred(ArrayInterface.known_dimnames(r3)) === (:y,)
@test @inferred(ArrayInterface.known_dimnames(r4)) === (:x, :y)
@test @inferred(ArrayInterface.known_dimnames(w)) === (:x, :y)
@test @inferred(ArrayInterface.known_dimnames(reshape(x, :))) === (:_,)
@test @inferred(ArrayInterface.known_dimnames(view(x, :, 1)')) === (:_, :x)
end
Expand Down