Skip to content

Commit 4924364

Browse files
authored
Indexing Variable objects with a mask (#1751)
* Variable indexing with a mask This will be useful for multi-dimensional reindexing: marking masked items with -1 is exactly the convention used by pandas.Index.get_indexer(). Example usage: In [6]: variable = xr.Variable(('x',), [1, 2, 3]) In [7]: variable._getitem_with_mask([0, 1, 2, -1]) Out[7]: <xarray.Variable (x: 4)> array([ 1., 2., 3., nan]) In [8]: variable._getitem_with_mask(xr.Variable(('x', 'y'), [[0, -1], [-1, 1]]), fill_value=-99) Out[8]: <xarray.Variable (x: 2, y: 2)> array([[ 1, -99], [-99, 2]]) This uses where() so it isn't the most efficient (there is some wasted effort doing indexing, as noted in the TODOs), but the implementation is pretty clean and already works with dask. For now, I'm leaving this as private API, but let's expose it publicly in the future if we are happy with it. I would probably leave it as a Variable method since this is pretty low-level. * More tests for _getitem_with_mask & fixes for dask
1 parent cb161a1 commit 4924364

File tree

4 files changed

+314
-24
lines changed

4 files changed

+314
-24
lines changed

xarray/core/indexing.py

Lines changed: 151 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from __future__ import print_function
44
from datetime import timedelta
55
from collections import defaultdict, Hashable
6+
import functools
67
import operator
78
import numpy as np
89
import pandas as pd
910

1011
from . import nputils
1112
from . import utils
13+
from . import duck_array_ops
1214
from .pycompat import (iteritems, range, integer_types, dask_array_type,
1315
suppress)
1416
from .utils import is_dict_like
@@ -589,27 +591,23 @@ def as_indexable(array):
589591
raise TypeError('Invalid array type: {}'.format(type(array)))
590592

591593

592-
def _outer_to_numpy_indexer(key, shape):
593-
"""Convert an OuterIndexer into an indexer for NumPy.
594+
def _outer_to_vectorized_indexer(key, shape):
595+
"""Convert an OuterIndexer into an vectorized indexer.
594596
595597
Parameters
596598
----------
597599
key : tuple
598-
Outer indexing tuple to convert.
600+
Tuple from an OuterIndexer to convert.
599601
shape : tuple
600602
Shape of the array subject to the indexing.
601603
602604
Returns
603605
-------
604606
tuple
605-
Base tuple suitable for use to index a NumPy array.
607+
Tuple suitable for use to index a NumPy array with vectorized indexing.
608+
Each element is an integer or array: broadcasting them together gives
609+
the shape of the result.
606610
"""
607-
if len([k for k in key if not isinstance(k, slice)]) <= 1:
608-
# If there is only one vector and all others are slice,
609-
# it can be safely used in mixed basic/advanced indexing.
610-
# Boolean index should already be converted to integer array.
611-
return tuple(key)
612-
613611
n_dim = len([k for k in key if not isinstance(k, integer_types)])
614612
i_dim = 0
615613
new_key = []
@@ -627,6 +625,149 @@ def _outer_to_numpy_indexer(key, shape):
627625
return tuple(new_key)
628626

629627

628+
def _outer_to_numpy_indexer(key, shape):
629+
"""Convert an OuterIndexer into an indexer for NumPy.
630+
631+
Parameters
632+
----------
633+
key : tuple
634+
Tuple from an OuterIndexer to convert.
635+
shape : tuple
636+
Shape of the array subject to the indexing.
637+
638+
Returns
639+
-------
640+
tuple
641+
Tuple suitable for use to index a NumPy array.
642+
"""
643+
if len([k for k in key if not isinstance(k, slice)]) <= 1:
644+
# If there is only one vector and all others are slice,
645+
# it can be safely used in mixed basic/advanced indexing.
646+
# Boolean index should already be converted to integer array.
647+
return tuple(key)
648+
else:
649+
return _outer_to_vectorized_indexer(key, shape)
650+
651+
652+
def _dask_array_with_chunks_hint(array, chunks):
    """Wrap ``array`` in a dask array, honoring ``chunks`` where size > 1.

    Parameters
    ----------
    array : array-like
        Object exposing ``ndim`` and ``shape`` that dask can wrap.
    chunks : tuple
        Per-dimension chunk hints; needs at least ``array.ndim`` entries.

    Returns
    -------
    dask.array.Array
        ``array`` wrapped with the hinted chunks. Any dimension whose size is
        not greater than 1 gets the chunks ``(1,)`` instead of its hint.

    Raises
    ------
    ValueError
        If ``chunks`` has fewer entries than ``array`` has dimensions.
    """
    import dask.array as da
    if array.ndim > len(chunks):
        raise ValueError('not enough chunks in hint')
    # zip() stops at the shorter input, so surplus hints are simply ignored.
    adjusted = [hint if dim_size > 1 else (1,)
                for hint, dim_size in zip(chunks, array.shape)]
    return da.from_array(array, adjusted)
661+
662+
663+
def _logical_any(args):
664+
return functools.reduce(operator.or_, args)
665+
666+
667+
def _masked_result_drop_slice(key, chunks_hint=None):
    """Build the mask contributed by the non-slice entries of ``key``.

    Slices are discarded. Each remaining entry is compared with -1 and the
    comparisons are OR-ed together. When ``chunks_hint`` is given, ndarray
    entries are first wrapped as dask arrays using that hint.
    """
    entries = [k for k in key if not isinstance(k, slice)]
    if chunks_hint is not None:
        wrapped = []
        for entry in entries:
            if isinstance(entry, np.ndarray):
                entry = _dask_array_with_chunks_hint(entry, chunks_hint)
            wrapped.append(entry)
        entries = wrapped
    return _logical_any(entry == -1 for entry in entries)
674+
675+
676+
def create_mask(indexer, shape, chunks_hint=None):
    """Create a mask for indexing with a fill-value.

    Parameters
    ----------
    indexer : ExplicitIndexer
        Indexer with -1 in integer or ndarray value to indicate locations in
        the result that should be masked.
    shape : tuple
        Shape of the array being indexed.
    chunks_hint : tuple, optional
        Optional tuple indicating desired chunks for the result. If provided,
        used as a hint for chunks on the resulting dask array. Must have a
        hint for each dimension on the result array.

    Returns
    -------
    mask : bool, np.ndarray or dask.array.Array with dtype=bool
        Dask array if chunks_hint is provided, otherwise a NumPy array. Has the
        same shape as the indexing result. For a BasicIndexer the mask is a
        single Python bool.

    Raises
    ------
    TypeError
        If ``indexer`` is not an OuterIndexer, VectorizedIndexer or
        BasicIndexer.
    """
    if isinstance(indexer, OuterIndexer):
        key = _outer_to_vectorized_indexer(indexer.tuple, shape)
        # The conversion above is expected to leave no slices behind.
        assert not any(isinstance(k, slice) for k in key)
        return _masked_result_drop_slice(key, chunks_hint)

    if isinstance(indexer, VectorizedIndexer):
        key = indexer.tuple
        base_mask = _masked_result_drop_slice(key, chunks_hint)
        # Length of each sliced dimension, computed without materializing
        # the selected positions.
        slice_shape = tuple(len(range(*k.indices(size)))
                            for k, size in zip(key, shape)
                            if isinstance(k, slice))
        # Append one trailing axis per slice, then broadcast the mask over
        # those axes.
        expanded_mask = base_mask[
            (Ellipsis,) + (np.newaxis,) * len(slice_shape)]
        return duck_array_ops.broadcast_to(
            expanded_mask, base_mask.shape + slice_shape)

    if isinstance(indexer, BasicIndexer):
        return any(k == -1 for k in indexer.tuple)

    raise TypeError('unexpected key type: {}'.format(type(indexer)))
720+
721+
722+
def _posify_mask_subindexer(index):
723+
"""Convert masked indices in a flat array to the nearest unmasked index.
724+
725+
Parameters
726+
----------
727+
index : np.ndarray
728+
One dimensional ndarray with dtype=int.
729+
730+
Returns
731+
-------
732+
np.ndarray
733+
One dimensional ndarray with all values equal to -1 replaced by an
734+
adjacent non-masked element.
735+
"""
736+
masked = index == -1
737+
unmasked_locs = np.flatnonzero(~masked)
738+
if not unmasked_locs.size:
739+
# indexing unmasked_locs is invalid
740+
return np.zeros_like(index)
741+
masked_locs = np.flatnonzero(masked)
742+
prev_value = np.maximum(0, np.searchsorted(unmasked_locs, masked_locs) - 1)
743+
new_index = index.copy()
744+
new_index[masked_locs] = index[unmasked_locs[prev_value]]
745+
return new_index
746+
747+
748+
def posify_mask_indexer(indexer):
    """Swap masked values (-1) in an indexer for nearby unmasked values.

    This routine is useful for dask, where it can be much faster to index
    adjacent points than arbitrary points from the end of an array.

    Parameters
    ----------
    indexer : ExplicitIndexer
        Input indexer.

    Returns
    -------
    ExplicitIndexer
        New indexer of the same type as the input. Every ndarray entry has
        its -1 values replaced by a neighbouring unmasked element; entries
        that are not ndarrays are passed through untouched.
    """
    converted = []
    for k in indexer.tuple:
        if isinstance(k, np.ndarray):
            # flatten, fix up the masked entries, then restore the shape
            k = _posify_mask_subindexer(k.ravel()).reshape(k.shape)
        converted.append(k)
    return type(indexer)(tuple(converted))
769+
770+
630771
class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
631772
"""Wrap a NumPy array to use explicit indexing."""
632773

xarray/core/variable.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,6 @@ def data(self, data):
307307
"replacement data must match the Variable's shape")
308308
self._data = data
309309

310-
@property
311-
def _indexable_data(self):
312-
return as_indexable(self._data)
313-
314310
def load(self, **kwargs):
315311
"""Manually trigger loading of this variable's data from disk or a
316312
remote source into memory and return this variable.
@@ -622,13 +618,56 @@ def __getitem__(self, key):
622618
If you really want to do indexing like `x[x > 0]`, manipulate the numpy
623619
array `x.values` directly.
624620
"""
625-
dims, index_tuple, new_order = self._broadcast_indexes(key)
626-
data = self._indexable_data[index_tuple]
621+
dims, indexer, new_order = self._broadcast_indexes(key)
622+
data = as_indexable(self._data)[indexer]
627623
if new_order:
628624
data = np.moveaxis(data, range(len(new_order)), new_order)
625+
return self._finalize_indexing_result(dims, data)
626+
627+
def _finalize_indexing_result(self, dims, data):
628+
"""Used by IndexVariable to return IndexVariable objects when possible.
629+
"""
629630
return type(self)(dims, data, self._attrs, self._encoding,
630631
fastpath=True)
631632

633+
def _getitem_with_mask(self, key, fill_value=dtypes.NA):
    """Index this Variable, writing fill_value wherever the key holds -1.

    When ``fill_value`` is left at ``dtypes.NA`` it is replaced by
    ``dtypes.get_fill_value(self.dtype)``.
    """
    # TODO(shoyer): expose this method in public API somewhere (isel?) and
    # use it for reindex.
    # TODO(shoyer): add a sanity check that all other integers are
    # non-negative
    # TODO(shoyer): add an optimization, remapping -1 to an adjacent value
    # that is actually indexed rather than mapping it to the last value
    # along each axis.

    if fill_value is dtypes.NA:
        fill_value = dtypes.get_fill_value(self.dtype)

    dims, indexer, new_order = self._broadcast_indexes(key)

    if not self.size:
        # array cannot be indexed along dimensions of size 0, so just
        # build the mask directly instead.
        mask = indexing.create_mask(indexer, self.shape)
        data = np.broadcast_to(fill_value, getattr(mask, 'shape', ()))
    else:
        # dask's indexing is faster with -1 remapped to a nearby index;
        # also vindex does not support negative indices yet:
        # https://github.com/dask/dask/pull/2967
        actual_indexer = (indexing.posify_mask_indexer(indexer)
                          if isinstance(self._data, dask_array_type)
                          else indexer)

        data = as_indexable(self._data)[actual_indexer]
        chunks_hint = getattr(data, 'chunks', None)
        # The mask is built from the original indexer, which still has -1.
        mask = indexing.create_mask(indexer, self.shape, chunks_hint)
        data = duck_array_ops.where(mask, fill_value, data)

    if new_order:
        data = np.moveaxis(data, range(len(new_order)), new_order)
    return self._finalize_indexing_result(dims, data)
670+
632671
def __setitem__(self, key, value):
633672
"""__setitem__ is overloaded to access the underlying numpy values with
634673
orthogonal indexing.
@@ -657,7 +696,8 @@ def __setitem__(self, key, value):
657696
(Ellipsis,)]
658697
value = np.moveaxis(value, new_order, range(len(new_order)))
659698

660-
self._indexable_data[index_tuple] = value
699+
indexable = as_indexable(self._data)
700+
indexable[index_tuple] = value
661701

662702
@property
663703
def attrs(self):
@@ -1468,14 +1508,12 @@ def chunk(self, chunks=None, name=None, lock=False):
14681508
# Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
14691509
return self.copy(deep=False)
14701510

1471-
def __getitem__(self, key):
1472-
dims, index_tuple, new_order = self._broadcast_indexes(key)
1473-
values = self._indexable_data[index_tuple]
1474-
if getattr(values, 'ndim', 0) != 1:
1511+
def _finalize_indexing_result(self, dims, data):
1512+
if getattr(data, 'ndim', 0) != 1:
14751513
# returns Variable rather than IndexVariable if multi-dimensional
1476-
return Variable(dims, values, self._attrs, self._encoding)
1514+
return Variable(dims, data, self._attrs, self._encoding)
14771515
else:
1478-
return type(self)(dims, values, self._attrs,
1516+
return type(self)(dims, data, self._attrs,
14791517
self._encoding, fastpath=True)
14801518

14811519
def __setitem__(self, key, value):

xarray/tests/test_indexing.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,79 @@ def nonzero(x):
388388
actual = indexing._outer_to_numpy_indexer(outer_index, v.shape)
389389
actual_data = v.data[actual]
390390
np.testing.assert_array_equal(actual_data, expected_data)
391+
392+
393+
def test_create_mask_outer_indexer():
    """create_mask flags the -1 positions of an OuterIndexer."""
    # single array key: only the -1 entry is masked
    key = indexing.OuterIndexer((np.array([0, -1, 2]),))
    mask = indexing.create_mask(key, (5,))
    np.testing.assert_array_equal(np.array([False, True, False]), mask)

    # integer + slice + array: the mask is repeated along the sliced axis
    key = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]),))
    mask = indexing.create_mask(key, (5, 5, 5,))
    np.testing.assert_array_equal(
        np.array(2 * [[False, True, False]]), mask)
403+
404+
405+
def test_create_mask_vectorized_indexer():
    """create_mask ORs the -1 positions of all VectorizedIndexer arrays."""
    first = np.array([0, -1, 2])
    second = np.array([0, 1, -1])

    # two array keys: a position is masked if either key holds -1 there
    key = indexing.VectorizedIndexer((first, second))
    mask = indexing.create_mask(key, (5,))
    np.testing.assert_array_equal(np.array([False, True, True]), mask)

    # a slice between the arrays adds a trailing broadcast dimension
    key = indexing.VectorizedIndexer((first, slice(None), second))
    mask = indexing.create_mask(key, (5, 2))
    np.testing.assert_array_equal(
        np.array([[False, True, True]] * 2).T, mask)
417+
418+
419+
def test_create_mask_basic_indexer():
    """create_mask on a BasicIndexer is True only when a key equals -1."""
    masked = indexing.create_mask(indexing.BasicIndexer((-1,)), (3,))
    np.testing.assert_array_equal(True, masked)

    unmasked = indexing.create_mask(indexing.BasicIndexer((0,)), (3,))
    np.testing.assert_array_equal(False, unmasked)
427+
428+
429+
def test_create_mask_dask():
    """create_mask returns dask masks that follow the chunks hint."""
    da = pytest.importorskip('dask.array')

    # outer indexer: the resulting chunks must equal the hint
    outer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2]),))
    outer_mask = indexing.create_mask(outer, (5, 5, 5,),
                                      chunks_hint=((1, 1), (2, 1)))
    assert outer_mask.chunks == ((1, 1), (2, 1))
    np.testing.assert_array_equal(
        np.array(2 * [[False, True, False]]), outer_mask)

    # vectorized indexer: the result is a dask array with the right values
    vectorized = indexing.VectorizedIndexer(
        (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1])))
    vectorized_mask = indexing.create_mask(
        vectorized, (5, 2), chunks_hint=((3,), (2,)))
    assert isinstance(vectorized_mask, da.Array)
    np.testing.assert_array_equal(
        np.array([[False, True, True]] * 2).T, vectorized_mask)

    # an empty hint has too few entries for the 1-D index arrays
    with pytest.raises(ValueError):
        indexing.create_mask(vectorized, (5, 2), chunks_hint=())
448+
449+
450+
def test_create_mask_error():
    """A plain tuple is not an accepted indexer type for create_mask."""
    bad_indexer = (1, 2)
    with raises_regex(TypeError, 'unexpected key type'):
        indexing.create_mask(bad_indexer, (3, 4))
453+
454+
455+
@pytest.mark.parametrize('indices, expected', [
    # no masked entries: unchanged
    (np.arange(5), np.arange(5)),
    # trailing masked entries reuse the previous unmasked value
    (np.array([0, -1, -1]), np.array([0, 0, 0])),
    # leading masked entries reuse the first unmasked value
    (np.array([-1, 1, -1]), np.array([1, 1, 1])),
    (np.array([-1, -1, 2]), np.array([2, 2, 2])),
    # everything masked: zeros
    (np.array([-1]), np.array([0])),
    (np.array([0, -1, 1, -1, -1]), np.array([0, 0, 1, 1, 1])),
    (np.array([0, -1, -1, -1, 1]), np.array([0, 0, 0, 0, 1])),
])
def test_posify_mask_subindexer(indices, expected):
    """Each -1 is replaced as listed in the parametrized cases."""
    result = indexing._posify_mask_subindexer(indices)
    np.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)