@@ -613,29 +613,110 @@ function Base.setindex!(a::Array, d::DArray,
613613 return a
614614end
615615
616+ # Similar to Base.indexin, but just create a logical mask
617+ indexin_mask (a, b:: Number ) = a .== b
618+ indexin_mask (a, r:: Range{Int} ) = [i in r for i in a]
619+ indexin_mask (a, b:: AbstractArray{Int} ) = indexin_mask (a, IntSet (b))
620+ indexin_mask (a, b:: AbstractArray ) = indexin_mask (a, Set (b))
621+ indexin_mask (a, b) = [i in b for i in a]
622+
623+ import Base: tail
624+ # Given a tuple of indices and a tuple of masks, restrict the indices to the
625+ # valid regions. This is, effectively, reversing Base.setindex_shape_check.
626+ # We can't just use indexing into MergedIndices here because getindex is much
627+ # pickier about singleton dimensions than setindex! is.
628+ restrict_indices (:: Tuple{} , :: Tuple{} ) = ()
629+ function restrict_indices (a:: Tuple{Any, Vararg{Any}} , b:: Tuple{Any, Vararg{Any}} )
630+ if (length (a[1 ]) == length (b[1 ]) == 1 ) || (length (a[1 ]) > 1 && length (b[1 ]) > 1 )
631+ (vec (a[1 ])[vec (b[1 ])], restrict_indices (tail (a), tail (b))... )
632+ elseif length (a[1 ]) == 1
633+ (a[1 ], restrict_indices (tail (a), b))
634+ elseif length (b[1 ]) == 1 && b[1 ][1 ]
635+ restrict_indices (a, tail (b))
636+ else
637+ throw (DimensionMismatch (" this should be caught by setindex_shape_check; please submit an issue" ))
638+ end
639+ end
640+ # The final indices are funky - they're allowed to accumulate together.
641+ # Too many masks is an easy fix -- just use the outer product to merge them:
642+ function restrict_indices (a:: Tuple{Any} , b:: Tuple{Any, Any, Vararg{Any}} )
643+ restrict_indices (a, (map (Bool, vec (vec (b[1 ])* vec (b[2 ])' )), tail (tail (b))... ))
644+ end
645+ # But too many indices is much harder; this will require merging the indices
646+ # in `a` before applying the final mask in `b`.
647+ function restrict_indices (a:: Tuple{Any, Any, Vararg{Any}} , b:: Tuple{Any} )
648+ if length (a[1 ]) == 1
649+ (a[1 ], restrict_indices (tail (a), b))
650+ else
651+ # When one mask spans multiple indices, we need to merge the indices
652+ # together. At this point, we can just use indexing to merge them since
653+ # there's no longer special handling of singleton dimensions
654+ (view (MergedIndices (a, map (length, a)), b[1 ]),)
655+ end
656+ end
657+
658+ immutable MergedIndices{T,N} <: AbstractArray{CartesianIndex{N}, N}
659+ indices:: T
660+ sz:: NTuple{N,Int}
661+ end
662+ Base. size (M:: MergedIndices ) = M. sz
663+ Base. getindex {_,N} (M:: MergedIndices{_,N} , I:: Vararg{Int, N} ) = CartesianIndex (map (getindex, M. indices, I))
664+ # Boundschecking for using MergedIndices as an array index. This is overly
665+ # strict -- even for SubArrays of ReshapedIndices, we require that the entire
666+ # parent array's indices are valid. In this usage, it is just fine... and is a
667+ # huge optimization over exact bounds checking.
668+ typealias ReshapedMergedIndices{T,N,M<: MergedIndices } Base. ReshapedArray{T,N,M}
669+ typealias SubMergedIndices{T,N,M<: Union{MergedIndices, ReshapedMergedIndices} } SubArray{T,N,M}
670+ typealias MergedIndicesOrSub Union{MergedIndices, SubMergedIndices}
671+ import Base: _chkbnds
672+ # Ambiguity with linear indexing:
673+ @inline _chkbnds (A:: AbstractVector , checked:: NTuple{1,Bool} , I:: MergedIndicesOrSub ) = _chkbnds (A, checked, parent (parent (I)). indices... )
674+ @inline _chkbnds (A:: AbstractArray , checked:: NTuple{1,Bool} , I:: MergedIndicesOrSub ) = _chkbnds (A, checked, parent (parent (I)). indices... )
675+ # Generic bounds checking
676+ @inline _chkbnds {T,N} (A:: AbstractArray{T,N} , checked:: NTuple{N,Bool} , I1:: MergedIndicesOrSub , I... ) = _chkbnds (A, checked, parent (parent (I1)). indices... , I... )
677+ @inline _chkbnds {T,N,M} (A:: AbstractArray{T,N} , checked:: NTuple{M,Bool} , I1:: MergedIndicesOrSub , I... ) = _chkbnds (A, checked, parent (parent (I1)). indices... , I... )
678+
679+ # The tricky thing here is that we want to optimize the accesses into the
680+ # distributed array, but in doing so, we lose track of which indices in I we
681+ # should be using.
682+ #
683+ # I’ve come to the conclusion that the function is utterly insane.
684+ # There are *6* flavors of indices with four different reference points:
685+ # 1. Find the indices of each portion of the DArray.
686+ # 2. Find the valid subset of indices for the SubArray into that portion.
687+ # 3. Find the portion of the `I` indices that should be used when you access the
688+ # `K` indices in the subarray. This guy is nasty. It’s totally backwards
689+ # from all other arrays, wherein we simply iterate over the source array’s
690+ # elements. You need to *both* know which elements in `J` were skipped
691+ # (`indexin_mask`) and which dimensions should match up (`restrict_indices`)
692+ # 4. If `K` doesn’t correspond to an entire chunk, reinterpret `K` in terms of
693+ # the local portion of the source array
616694function Base. setindex! (a:: Array , s:: SubDArray ,
617695 I:: Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}} ...)
696+ Base. setindex_shape_check (s, Base. index_lengths (a, I... )... )
618697 n = length (I)
619698 d = s. parent
620- J = s. indexes
699+ J = Base . decolon (d, s. indexes... )
621700 if length (J) < n
701+ # TODO : this failsafe only works sometimes; the proper solution is to
702+ # implement `restrict_indices` to merge the indices above.
622703 a[I... ] = convert (Array,s)
623704 return a
624705 end
625- offs = [isa (J[i],Int) ? J[i]- 1 : first (J[i])- 1 for i= 1 : n]
626706 @sync for i = 1 : length (d. pids)
627- K_c = Any[ d. indexes[i] . .. ]
628- K = [ intersect (J[j], K_c[j]) for j = 1 : n ]
707+ K_c = d. indexes[i]
708+ K = map (intersect, J, K_c)
629709 if ! any (isempty, K)
630- idxs = [ I[j][K[j]- offs[j]] for j= 1 : n ]
710+ K_mask = map (indexin_mask, J, K_c)
711+ idxs = restrict_indices (Base. decolon (a, I... ), K_mask)
631712 if isequal (K, K_c)
632713 # whole chunk
633714 @async a[idxs... ] = chunk (d, i)
634715 else
635716 # partial chunk
636717 @async a[idxs... ] =
637718 remotecall_fetch (d. pids[i]) do
638- sub (localpart (d), [K[j]- first (K_c[j])+ 1 for j= 1 : n ]. .. )
719+ view (localpart (d), [K[j]- first (K_c[j])+ 1 for j= 1 : length (J) ]. .. )
639720 end
640721 end
641722 end
0 commit comments