Skip to content

Commit 6e14df6

Browse files
authored
Flexible indexes: add Index base class and xindexes properties (#5102)
* add IndexAdapter class + move PandasIndexAdapter * wip: xarray_obj.indexes -> IndexAdapter objects * fix more broken tests * fix merge glitch * fix group bins tests * add xindexes property Use it internally instead of indexes * rename IndexAdapter -> Index * rename _to_index_adpater (typo) -> _to_xindex * add Index.to_pandas_index() method Also improve xarray_obj.indexes property implementation * rename PandasIndexAdapter -> PandasIndex * update index type in tests * ensure .indexes only returns pd.Index objects * PandasIndex: normalize other index in cmp funcs * fix merge lint errors * fix PandasIndex union/intersection * [skip-ci] add TODO comment about index sizes * address more PR comments * [skip-ci] update what's new * fix coord_names normalization * move what's new entry to unreleased section
1 parent 234b40a commit 6e14df6

22 files changed

+534
-311
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ Documentation
4242
Internal Changes
4343
~~~~~~~~~~~~~~~~
4444

45+
- Explicit indexes refactor: add an ``xarray.Index`` base class and
46+
``Dataset.xindexes`` / ``DataArray.xindexes`` properties. Also rename
47+
``PandasIndexAdapter`` to ``PandasIndex``, which now inherits from
48+
``xarray.Index`` (:pull:`5102`). By `Benoit Bovy <https://github.com/benbovy>`_.
4549

4650
.. _whats-new.0.18.0:
4751

@@ -268,7 +272,6 @@ Internal Changes
268272
(:pull:`5188`), (:pull:`5191`).
269273
By `Maximilian Roos <https://github.com/max-sixty>`_.
270274

271-
272275
.. _whats-new.0.17.0:
273276

274277
v0.17.0 (24 Feb 2021)

xarray/core/alignment.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import numpy as np
1818
import pandas as pd
1919

20-
from . import dtypes, utils
20+
from . import dtypes
21+
from .indexes import Index, PandasIndex
2122
from .indexing import get_indexer_nd
22-
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str
23+
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
2324
from .variable import IndexVariable, Variable
2425

2526
if TYPE_CHECKING:
@@ -30,11 +31,11 @@
3031
DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)
3132

3233

33-
def _get_joiner(join):
34+
def _get_joiner(join, index_cls):
3435
if join == "outer":
35-
return functools.partial(functools.reduce, pd.Index.union)
36+
return functools.partial(functools.reduce, index_cls.union)
3637
elif join == "inner":
37-
return functools.partial(functools.reduce, pd.Index.intersection)
38+
return functools.partial(functools.reduce, index_cls.intersection)
3839
elif join == "left":
3940
return operator.itemgetter(0)
4041
elif join == "right":
@@ -63,7 +64,7 @@ def _override_indexes(objects, all_indexes, exclude):
6364
objects = list(objects)
6465
for idx, obj in enumerate(objects[1:]):
6566
new_indexes = {}
66-
for dim in obj.indexes:
67+
for dim in obj.xindexes:
6768
if dim not in exclude:
6869
new_indexes[dim] = all_indexes[dim][0]
6970
objects[idx + 1] = obj._overwrite_indexes(new_indexes)
@@ -284,7 +285,7 @@ def align(
284285
if dim not in exclude:
285286
all_coords[dim].append(obj.coords[dim])
286287
try:
287-
index = obj.indexes[dim]
288+
index = obj.xindexes[dim]
288289
except KeyError:
289290
unlabeled_dim_sizes[dim].add(obj.sizes[dim])
290291
else:
@@ -298,16 +299,19 @@ def align(
298299
# - It ensures it's possible to do operations that don't require alignment
299300
# on indexes with duplicate values (which cannot be reindexed with
300301
# pandas). This is useful, e.g., for overwriting such duplicate indexes.
301-
joiner = _get_joiner(join)
302302
joined_indexes = {}
303303
for dim, matching_indexes in all_indexes.items():
304304
if dim in indexes:
305-
index = utils.safe_cast_to_index(indexes[dim])
305+
# TODO: benbovy - flexible indexes. maybe move this logic in util func
306+
if isinstance(indexes[dim], Index):
307+
index = indexes[dim]
308+
else:
309+
index = PandasIndex(safe_cast_to_index(indexes[dim]))
306310
if (
307311
any(not index.equals(other) for other in matching_indexes)
308312
or dim in unlabeled_dim_sizes
309313
):
310-
joined_indexes[dim] = indexes[dim]
314+
joined_indexes[dim] = index
311315
else:
312316
if (
313317
any(
@@ -318,6 +322,7 @@ def align(
318322
):
319323
if join == "exact":
320324
raise ValueError(f"indexes along dimension {dim!r} are not equal")
325+
joiner = _get_joiner(join, type(matching_indexes[0]))
321326
index = joiner(matching_indexes)
322327
# make sure str coords are not cast to object
323328
index = maybe_coerce_to_str(index, all_coords[dim])
@@ -327,6 +332,9 @@ def align(
327332

328333
if dim in unlabeled_dim_sizes:
329334
unlabeled_sizes = unlabeled_dim_sizes[dim]
335+
# TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
336+
# Some indexes may not have a defined size (e.g., built from multiple coords of
337+
# different sizes)
330338
labeled_size = index.size
331339
if len(unlabeled_sizes | {labeled_size}) > 1:
332340
raise ValueError(
@@ -469,7 +477,7 @@ def reindex_like_indexers(
469477
ValueError
470478
If any dimensions without labels have different sizes.
471479
"""
472-
indexers = {k: v for k, v in other.indexes.items() if k in target.dims}
480+
indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
473481

474482
for dim in other.dims:
475483
if dim not in indexers and dim in target.dims:
@@ -487,14 +495,14 @@ def reindex_like_indexers(
487495
def reindex_variables(
488496
variables: Mapping[Any, Variable],
489497
sizes: Mapping[Any, int],
490-
indexes: Mapping[Any, pd.Index],
498+
indexes: Mapping[Any, Index],
491499
indexers: Mapping,
492500
method: Optional[str] = None,
493501
tolerance: Any = None,
494502
copy: bool = True,
495503
fill_value: Optional[Any] = dtypes.NA,
496504
sparse: bool = False,
497-
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]:
505+
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
498506
"""Conform a dictionary of aligned variables onto a new set of variables,
499507
filling in missing values with NaN.
500508
@@ -559,10 +567,11 @@ def reindex_variables(
559567
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
560568
)
561569

562-
target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim])
570+
target = new_indexes[dim] = PandasIndex(safe_cast_to_index(indexers[dim]))
563571

564572
if dim in indexes:
565-
index = indexes[dim]
573+
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?
574+
index = indexes[dim].to_pandas_index()
566575

567576
if not index.is_unique:
568577
raise ValueError(

xarray/core/combine.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ def _infer_concat_order_from_coords(datasets):
6969
if dim in ds0:
7070

7171
# Need to read coordinate values to do ordering
72-
indexes = [ds.indexes.get(dim) for ds in datasets]
72+
indexes = [ds.xindexes.get(dim) for ds in datasets]
7373
if any(index is None for index in indexes):
7474
raise ValueError(
7575
"Every dimension needs a coordinate for "
7676
"inferring concatenation order"
7777
)
7878

79+
# TODO (benbovy, flexible indexes): all indexes should be pandas.Index
80+
# get pd.Index objects from Index objects
81+
indexes = [index.array for index in indexes]
82+
7983
# If dimension coordinate values are same on every dataset then
8084
# should be leaving this dimension alone (it's just a "bystander")
8185
if not all(index.equals(indexes[0]) for index in indexes[1:]):
@@ -801,9 +805,13 @@ def combine_by_coords(
801805
)
802806

803807
# Check the overall coordinates are monotonically increasing
808+
# TODO (benbovy - flexible indexes): only with pandas.Index?
804809
for dim in concat_dims:
805-
indexes = concatenated.indexes.get(dim)
806-
if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing):
810+
indexes = concatenated.xindexes.get(dim)
811+
if not (
812+
indexes.array.is_monotonic_increasing
813+
or indexes.array.is_monotonic_decreasing
814+
):
807815
raise ValueError(
808816
"Resulting object does not have monotonic"
809817
" global indexes along dimension {}".format(dim)

xarray/core/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def get_index(self, key: Hashable) -> pd.Index:
406406
raise KeyError(key)
407407

408408
try:
409-
return self.indexes[key]
409+
return self.xindexes[key].to_pandas_index()
410410
except KeyError:
411411
return pd.Index(range(self.sizes[key]), name=key)
412412

@@ -1162,7 +1162,8 @@ def resample(
11621162
category=FutureWarning,
11631163
)
11641164

1165-
if isinstance(self.indexes[dim_name], CFTimeIndex):
1165+
# TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass
1166+
if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex):
11661167
from .resample_cftime import CFTimeGrouper
11671168

11681169
grouper = CFTimeGrouper(freq, closed, label, base, loffset)

xarray/core/coordinates.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pandas as pd
1818

1919
from . import formatting, indexing
20-
from .indexes import Indexes
20+
from .indexes import Index, Indexes
2121
from .merge import merge_coordinates_without_align, merge_coords
2222
from .utils import Frozen, ReprObject, either_dict_or_kwargs
2323
from .variable import Variable
@@ -52,6 +52,10 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:
5252
def indexes(self) -> Indexes:
5353
return self._data.indexes # type: ignore[attr-defined]
5454

55+
@property
56+
def xindexes(self) -> Indexes:
57+
return self._data.xindexes # type: ignore[attr-defined]
58+
5559
@property
5660
def variables(self):
5761
raise NotImplementedError()
@@ -157,15 +161,15 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
157161
def update(self, other: Mapping[Hashable, Any]) -> None:
158162
other_vars = getattr(other, "variables", other)
159163
coords, indexes = merge_coords(
160-
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
164+
[self.variables, other_vars], priority_arg=1, indexes=self.xindexes
161165
)
162166
self._update_coords(coords, indexes)
163167

164168
def _merge_raw(self, other, reflexive):
165169
"""For use with binary arithmetic."""
166170
if other is None:
167171
variables = dict(self.variables)
168-
indexes = dict(self.indexes)
172+
indexes = dict(self.xindexes)
169173
else:
170174
coord_list = [self, other] if not reflexive else [other, self]
171175
variables, indexes = merge_coordinates_without_align(coord_list)
@@ -180,7 +184,9 @@ def _merge_inplace(self, other):
180184
# don't include indexes in prioritized, because we didn't align
181185
# first and we want indexes to be checked
182186
prioritized = {
183-
k: (v, None) for k, v in self.variables.items() if k not in self.indexes
187+
k: (v, None)
188+
for k, v in self.variables.items()
189+
if k not in self.xindexes
184190
}
185191
variables, indexes = merge_coordinates_without_align(
186192
[self, other], prioritized
@@ -265,7 +271,7 @@ def to_dataset(self) -> "Dataset":
265271
return self._data._copy_listed(names)
266272

267273
def _update_coords(
268-
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index]
274+
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
269275
) -> None:
270276
from .dataset import calculate_dimensions
271277

@@ -285,7 +291,7 @@ def _update_coords(
285291

286292
# TODO(shoyer): once ._indexes is always populated by a dict, modify
287293
# it to update inplace instead.
288-
original_indexes = dict(self._data.indexes)
294+
original_indexes = dict(self._data.xindexes)
289295
original_indexes.update(indexes)
290296
self._data._indexes = original_indexes
291297

@@ -328,7 +334,7 @@ def __getitem__(self, key: Hashable) -> "DataArray":
328334
return self._data._getitem_coord(key)
329335

330336
def _update_coords(
331-
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index]
337+
self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
332338
) -> None:
333339
from .dataset import calculate_dimensions
334340

@@ -343,7 +349,7 @@ def _update_coords(
343349

344350
# TODO(shoyer): once ._indexes is always populated by a dict, modify
345351
# it to update inplace instead.
346-
original_indexes = dict(self._data.indexes)
352+
original_indexes = dict(self._data.xindexes)
347353
original_indexes.update(indexes)
348354
self._data._indexes = original_indexes
349355

xarray/core/dataarray.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
)
5252
from .dataset import Dataset, split_indexes
5353
from .formatting import format_item
54-
from .indexes import Indexes, default_indexes, propagate_indexes
54+
from .indexes import Index, Indexes, PandasIndex, default_indexes, propagate_indexes
5555
from .indexing import is_fancy_indexer
5656
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
5757
from .options import OPTIONS, _get_keep_attrs
@@ -345,7 +345,7 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic):
345345
_cache: Dict[str, Any]
346346
_coords: Dict[Any, Variable]
347347
_close: Optional[Callable[[], None]]
348-
_indexes: Optional[Dict[Hashable, pd.Index]]
348+
_indexes: Optional[Dict[Hashable, Index]]
349349
_name: Optional[Hashable]
350350
_variable: Variable
351351

@@ -478,7 +478,9 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
478478
# switch from dimension to level names, if necessary
479479
dim_names: Dict[Any, str] = {}
480480
for dim, idx in indexes.items():
481-
if not isinstance(idx, pd.MultiIndex) and idx.name != dim:
481+
# TODO: benbovy - flexible indexes: update when MultiIndex has its own class
482+
pd_idx = idx.array
483+
if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
482484
dim_names[dim] = idx.name
483485
if dim_names:
484486
obj = obj.rename(dim_names)
@@ -772,7 +774,21 @@ def encoding(self, value: Mapping[Hashable, Any]) -> None:
772774

773775
@property
774776
def indexes(self) -> Indexes:
775-
"""Mapping of pandas.Index objects used for label based indexing"""
777+
"""Mapping of pandas.Index objects used for label based indexing.
778+
779+
Raises an error if this DataArray has indexes that cannot be coerced
780+
to pandas.Index objects.
781+
782+
See Also
783+
--------
784+
DataArray.xindexes
785+
786+
"""
787+
return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()})
788+
789+
@property
790+
def xindexes(self) -> Indexes:
791+
"""Mapping of xarray Index objects used for label based indexing."""
776792
if self._indexes is None:
777793
self._indexes = default_indexes(self._coords, self.dims)
778794
return Indexes(self._indexes)
@@ -990,7 +1006,12 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
9901006
if self._indexes is None:
9911007
indexes = self._indexes
9921008
else:
993-
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
1009+
# TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
1010+
# xarray Index needs a copy method.
1011+
indexes = {
1012+
k: PandasIndex(v.to_pandas_index().copy(deep=deep))
1013+
for k, v in self._indexes.items()
1014+
}
9941015
return self._replace(variable, coords, indexes=indexes)
9951016

9961017
def __copy__(self) -> "DataArray":
@@ -2169,7 +2190,9 @@ def to_unstacked_dataset(self, dim, level=0):
21692190
Dataset.to_stacked_array
21702191
"""
21712192

2172-
idx = self.indexes[dim]
2193+
# TODO: benbovy - flexible indexes: update when MultiIndex has its own
2194+
# class inheriting from xarray.Index
2195+
idx = self.xindexes[dim].to_pandas_index()
21732196
if not isinstance(idx, pd.MultiIndex):
21742197
raise ValueError(f"'{dim}' is not a stacked coordinate")
21752198

0 commit comments

Comments
 (0)