Skip to content

Commit 93dde6c

Browse files
committed
code clean-up
1 parent 8d33937 commit 93dde6c

File tree

5 files changed

+43
-53
lines changed

5 files changed

+43
-53
lines changed

lib/ArrayInterfaceCore/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ArrayInterfaceCore
22
using ArrayInterfaceCore: zeromatrix
33
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
4-
parent_type, zeromatrix, IndicesInfo, indices_to_dims
4+
parent_type, zeromatrix, IndicesInfo
55
using Base: setindex
66
using LinearAlgebra
77
using Random

src/axes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ end
116116
axes(A::VecAdjTrans) = (SOneTo{1}(), axes(parent(A), 1))
117117

118118
@inline axes(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_axes, x), sub_axes_map(typeof(x))))
119-
@inline _sub_axes(x::SubArray, ::Pair{StaticSymbol{:parent},StaticInt{s}}) where {s} = StaticInt(1):StaticInt(s)
119+
@inline _sub_axes(x::SubArray, axis::SOneTo) = axis
120120
_sub_axes(x::SubArray, ::StaticInt{index}) where {index} = axes(getfield(x.indices, index))
121121

122122
@inline axes(A, dim) = _axes(A, to_dims(A, dim))

src/dimensions.jl

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ function _init_dimsmap(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
55
ntuple(i -> static(getfield(pdims, i)), length(pdims)),
66
ntuple(i -> static(getfield(cdims, i)), length(pdims))
77
end
8-
# TODO move these
9-
Static.static(::Colon) = (:)
10-
Static.static(::Nothing) = nothing
118

129
"""
1310
to_parent_dims(::Type{T}) -> Tuple{Vararg{Union{StaticInt,Tuple{Vararg{StaticInt}}}}}
@@ -22,17 +19,10 @@ to_parent_dims(@nospecialize x) = to_parent_dims(typeof(x))
2219
@inline function to_parent_dims(@nospecialize T::Type{<:SubArray})
2320
to_parent_dims(IndicesInfo{ndims(parent_type(T))}(fieldtype(T, :indices)))
2421
end
25-
function to_parent_dims(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
26-
flatten_tuples(ntuple(length(cdims)) do i
27-
cdim_i = getfield(cdims, i)
28-
if cdim_i isa Tuple
29-
pdim_i = static(getfield(pdims, i))
30-
ntuple(Compat.Returns(pdim_i), length(cdim_i))
31-
else
32-
cdim_i === 0 ? () : (static(getfield(pdims, i)),)
33-
end
34-
end)
35-
end
22+
to_parent_dims(info::IndicesInfo) = flatten_tuples(map(_to_pdim, map_indices_info(info)))
23+
_to_pdim(::Tuple{StaticInt,Any,StaticInt{0}}) = ()
24+
_to_pdim(x::Tuple{StaticInt,Any,StaticInt{cdim}}) where {cdim} = getfield(x, 2)
25+
_to_pdim(x::Tuple{StaticInt,Any,Tuple}) = (ntuple(Compat.Returns(getfield(x, 2)), length(getfield(x, 3))),)
3626
to_parent_dims(@nospecialize T::Type{<:MatAdjTrans}) = (StaticInt(2), StaticInt(1))
3727
to_parent_dims(@nospecialize T::Type{<:PermutedDimsArray}) = getfield(_permdims(T), 1)
3828

@@ -46,30 +36,33 @@ end
4636

4737
# Base will sometomes demote statically known slices in `SubArray` to `OneTo{Int}` so we
4838
# provide the parent mapping to check for static size info
49-
sub_axes_map(@nospecialize(T::Type{<:SubArray})) = _sub_axes_map(T, IndicesInfo(T))
50-
function _sub_axes_map(@nospecialize(T::Type{<:SubArray}), ::IndicesInfo{N,pdims}) where {N,pdims}
51-
ntuple(length(pdims)) do i
52-
if fieldtype(fieldtype(T, :indices), i) <: Base.Slice{OneTo{Int}}
53-
sz = known_size(parent_type(T), getfield(pdims, i))
54-
sz === nothing ? StaticInt(i) : StaticSymbol(:parent) => StaticInt(sz)
55-
else
56-
StaticInt(i)
57-
end
39+
function sub_axes_map(@nospecialize(T::Type{<:SubArray}))
40+
map(Base.Fix1(_sub_axis_map, T), map_indices_info(IndicesInfo(T)))
41+
end
42+
function _sub_axis_map(@nospecialize(T::Type{<:SubArray}), x::Tuple{StaticInt{index},Any,Any}) where {index}
43+
if fieldtype(fieldtype(T, :indices), index) <: Base.Slice{OneTo{Int}}
44+
sz = known_size(parent_type(T), getfield(x, 2))
45+
return sz === nothing ? StaticInt(index) : StaticInt(1):StaticInt(sz)
46+
else
47+
return StaticInt(index)
5848
end
5949
end
6050

61-
sub_dimnames_map(@nospecialize T::Type{<:SubArray}) = _sub_dimnames_map(IndicesInfo(T))
62-
function _sub_dimnames_map(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
63-
ntuple(length(pdims)) do i
64-
cdim_i = getfield(cdims, i)
65-
pdim_i = getfield(pdims, i)
66-
if cdim_i === 0
67-
StaticSymbol(:underscore) => StaticInt(0)
68-
elseif pdim_i isa Int && cdim_i isa Int
69-
StaticInt(pdim_i)
70-
else
71-
StaticSymbol(:underscore) => length(cdim_i)
72-
end
51+
function map_indices_info(::IndicesInfo{N,pdims,cdims}) where {N,pdims,cdims}
52+
ntuple(i -> (static(i), static(getfield(pdims, i)), static(getfield(cdims, i))), length(pdims))
53+
end
54+
function sub_dimnames_map(dnames::Tuple, imap::Tuple)
55+
flatten_tuples(map(Base.Fix1(_to_dimname, dnames), imap))
56+
end
57+
@inline function _to_dimname(dnames::Tuple, x::Tuple{StaticInt,PD,CD}) where {PD,CD}
58+
if CD <: StaticInt{0}
59+
return ()
60+
elseif CD <: Tuple
61+
return ntuple(Compat.Returns(static(:_)), StaticInt(known_length(CD)))
62+
elseif PD <: StaticInt{0} || PD <: Tuple
63+
return static(:_)
64+
else
65+
return getfield(dnames, known(PD))
7366
end
7467
end
7568

@@ -122,11 +115,7 @@ function known_dimnames(@nospecialize T::Type{<:Union{MatAdjTrans,PermutedDimsAr
122115
end
123116

124117
function known_dimnames(@nospecialize T::Type{<:SubArray})
125-
flatten_tuples(map(Base.Fix1(_known_sub_dimname, known_dimnames(parent_type(T))), sub_dimnames_map(T)))
126-
end
127-
_known_sub_dimname(dn::Tuple, ::StaticInt{dim}) where {dim} = getfield(dn, dim)
128-
function _known_sub_dimname(::Tuple, ::Pair{StaticSymbol{:underscore},StaticInt{N}}) where {N}
129-
ntuple(Compat.Returns(:_), StaticInt(N))
118+
dynamic(sub_dimnames_map(known_dimnames(parent_type(T)), map_indices_info(IndicesInfo(T))))
130119
end
131120

132121
function known_dimnames(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
@@ -184,11 +173,7 @@ have a name.
184173
end
185174

186175
function dimnames(x::SubArray)
187-
flatten_tuples(map(Base.Fix1(_sub_dimname, dimnames(parent(x))), sub_dimnames_map(typeof(x))))
188-
end
189-
_sub_dimname(dn::Tuple, ::StaticInt{dim}) where {dim} = getfield(dn, dim)
190-
function _sub_dimname(::Tuple, ::Pair{StaticSymbol{:underscore},StaticInt{N}}) where {N}
191-
ntuple(Compat.Returns(static(:_)), StaticInt(N))
176+
sub_dimnames_map(dimnames(parent(x)), map_indices_info(IndicesInfo(typeof(x))))
192177
end
193178

194179
dimnames(x::VecAdjTrans) = (static(:_), getfield(dimnames(parent(x)), 1))

src/size.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ _maybe_size(::Base.HasShape{N}, a::A) where {N,A} = map(length, axes(a))
2929
_maybe_size(::Base.HasLength, a::A) where {A} = (length(a),)
3030

3131
@inline size(x::SubArray) = flatten_tuples(map(Base.Fix1(_sub_size, x), sub_axes_map(typeof(x))))
32-
@inline _sub_size(::SubArray, ::Pair{StaticSymbol{:parent},StaticInt{s}}) where {s} = StaticInt(s)
32+
@inline _sub_size(::SubArray, ::SOneTo{S}) where {S} = StaticInt(S)
3333
_sub_size(x::SubArray, ::StaticInt{index}) where {index} = size(getfield(x.indices, index))
3434

3535
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@@ -150,9 +150,13 @@ end
150150
ntuple(i -> known_length(I.parameters[i]), Val(ndims(T)))
151151
end
152152

153-
@inline known_size(@nospecialize T::Type{<:SubArray}) = flatten_tuples(map(Base.Fix1(_known_sub_size, T), sub_axes_map(T)))
154-
_known_sub_size(@nospecialize(T::Type{<:SubArray}), ::Pair{StaticSymbol{:parent},StaticInt{s}}) where {s} = s
155-
_known_sub_size(@nospecialize(T::Type{<:SubArray}), ::StaticInt{index}) where {index}= known_size(fieldtype(fieldtype(T, :indices), index))
153+
@inline function known_size(@nospecialize T::Type{<:SubArray})
154+
flatten_tuples(map(Base.Fix1(_known_sub_size, T), sub_axes_map(T)))
155+
end
156+
_known_sub_size(@nospecialize(T::Type{<:SubArray}), ::SOneTo{S}) where {S} = S
157+
function _known_sub_size(@nospecialize(T::Type{<:SubArray}), ::StaticInt{index}) where {index}
158+
known_size(fieldtype(fieldtype(T, :indices), index))
159+
end
156160

157161
# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,
158162
# but we assume as much b/c otherwise it will error while iterating. So we promote to the

src/stridelayout.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,8 @@ _contiguous_batch_size(::StaticInt{-1}, ::R) where {R<:Tuple} = -One()
324324

325325
contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = Zero()
326326
contiguous_batch_size(::Type{BitArray{N}}) where {N} = Zero()
327-
contiguous_batch_size(::Type{<:AbstractRange}) = Zero()
328-
contiguous_batch_size(::Type{<:Tuple}) = Zero()
327+
contiguous_batch_size(@nospecialize T::Type{<:AbstractRange}) = Zero()
328+
contiguous_batch_size(@nospecialize T::Type{<:Tuple}) = Zero()
329329
@inline function contiguous_batch_size(@nospecialize T::Type{<:Union{PermutedDimsArray,Transpose,Adjoint}})
330330
contiguous_batch_size(parent_type(T))
331331
end
@@ -413,6 +413,7 @@ end
413413
return _dense_dims(T, dd, Val(stride_rank(parent_type(T))))
414414
end
415415
end
416+
416417
@generated function _dense_dims(
417418
::Type{S},
418419
::D,

0 commit comments

Comments
 (0)