Skip to content

Commit 0086e32

Browse files
committed
wip refactor set_index
1 parent 021090f commit 0086e32

File tree

2 files changed

+189
-120
lines changed

2 files changed

+189
-120
lines changed

xarray/core/dataset.py

Lines changed: 84 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -194,90 +194,6 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, in
194194
return dims
195195

196196

197-
def merge_indexes(
198-
indexes: Mapping[Any, Union[Hashable, Sequence[Hashable]]],
199-
variables: Mapping[Any, Variable],
200-
coord_names: Set[Hashable],
201-
append: bool = False,
202-
) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
203-
"""Merge variables into multi-indexes.
204-
205-
Not public API. Used in Dataset and DataArray set_index
206-
methods.
207-
"""
208-
vars_to_replace: Dict[Hashable, Variable] = {}
209-
vars_to_remove: List[Hashable] = []
210-
dims_to_replace: Dict[Hashable, Hashable] = {}
211-
error_msg = "{} is not the name of an existing variable."
212-
213-
for dim, var_names in indexes.items():
214-
if isinstance(var_names, str) or not isinstance(var_names, Sequence):
215-
var_names = [var_names]
216-
217-
names: List[Hashable] = []
218-
codes: List[List[int]] = []
219-
levels: List[List[int]] = []
220-
current_index_variable = variables.get(dim)
221-
222-
for n in var_names:
223-
try:
224-
var = variables[n]
225-
except KeyError:
226-
raise ValueError(error_msg.format(n))
227-
if (
228-
current_index_variable is not None
229-
and var.dims != current_index_variable.dims
230-
):
231-
raise ValueError(
232-
f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}"
233-
)
234-
235-
if current_index_variable is not None and append:
236-
current_index = current_index_variable.to_index()
237-
if isinstance(current_index, pd.MultiIndex):
238-
names.extend(current_index.names)
239-
codes.extend(current_index.codes)
240-
levels.extend(current_index.levels)
241-
else:
242-
names.append(f"{dim}_level_0")
243-
cat = pd.Categorical(current_index.values, ordered=True)
244-
codes.append(cat.codes)
245-
levels.append(cat.categories)
246-
247-
if not len(names) and len(var_names) == 1:
248-
idx = pd.Index(variables[var_names[0]].values)
249-
250-
else: # MultiIndex
251-
for n in var_names:
252-
try:
253-
var = variables[n]
254-
except KeyError:
255-
raise ValueError(error_msg.format(n))
256-
names.append(n)
257-
cat = pd.Categorical(var.values, ordered=True)
258-
codes.append(cat.codes)
259-
levels.append(cat.categories)
260-
261-
idx = pd.MultiIndex(levels, codes, names=names)
262-
for n in names:
263-
dims_to_replace[n] = dim
264-
265-
vars_to_replace[dim] = IndexVariable(dim, idx)
266-
vars_to_remove.extend(var_names)
267-
268-
new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove}
269-
new_variables.update(vars_to_replace)
270-
271-
# update dimensions if necessary, GH: 3512
272-
for k, v in new_variables.items():
273-
if any(d in dims_to_replace for d in v.dims):
274-
new_dims = [dims_to_replace.get(d, d) for d in v.dims]
275-
new_variables[k] = v._replace(dims=new_dims)
276-
new_coord_names = coord_names | set(vars_to_replace)
277-
new_coord_names -= set(vars_to_remove)
278-
return new_variables, new_coord_names
279-
280-
281197
def split_indexes(
282198
dims_or_levels: Union[Hashable, Sequence[Hashable]],
283199
variables: Mapping[Any, Variable],
@@ -3307,7 +3223,7 @@ def _rename_dims(self, name_dict):
33073223

33083224
def _rename_indexes(self, name_dict, dims_dict):
33093225
if self._indexes is None:
3310-
return None, {}
3226+
return {}, {}
33113227

33123228
indexes = {}
33133229
variables = {}
@@ -3751,11 +3667,90 @@ def set_index(
37513667
Dataset.reset_index
37523668
Dataset.swap_dims
37533669
"""
3754-
indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
3755-
variables, coord_names = merge_indexes(
3756-
indexes, self._variables, self._coord_names, append=append
3670+
dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
3671+
3672+
new_indexes: Dict[Hashable, Index] = {}
3673+
new_variables: Dict[Hashable, IndexVariable] = {}
3674+
maybe_drop_indexes: List[Hashable] = []
3675+
drop_variables: List[Hashable] = []
3676+
replace_dims: Dict[Hashable, Hashable] = {}
3677+
3678+
index_coord_names = {
3679+
k: coord_names
3680+
for _, coord_names in group_coords_by_index(self.xindexes)
3681+
for k in coord_names
3682+
}
3683+
3684+
for dim, _var_names in dim_coords.items():
3685+
if isinstance(_var_names, str) or not isinstance(_var_names, Sequence):
3686+
var_names = [_var_names]
3687+
else:
3688+
var_names = list(_var_names)
3689+
3690+
invalid_vars = set(var_names) - set(self._variables)
3691+
if invalid_vars:
3692+
raise ValueError(
3693+
", ".join([str(v) for v in invalid_vars])
3694+
+ " variable(s) do not exist"
3695+
)
3696+
3697+
current_coord_names = index_coord_names.get(dim, [])
3698+
3699+
# drop any pre-existing index involved
3700+
maybe_drop_indexes.extend(current_coord_names + var_names)
3701+
for k in var_names:
3702+
maybe_drop_indexes.extend(index_coord_names.get(k, []))
3703+
3704+
drop_variables.extend(var_names)
3705+
3706+
if len(var_names) == 1 and (not append or dim not in self.xindexes):
3707+
var_name = var_names[0]
3708+
var = self._variables[var_name]
3709+
if var.dims != (dim,):
3710+
raise ValueError(
3711+
f"dimension mismatch: try setting an index for dimension {dim!r} with "
3712+
f"variable {var_name!r} that has dimensions {var.dims}"
3713+
)
3714+
idx, idx_vars = PandasIndex.from_variables({dim: var})
3715+
else:
3716+
if append:
3717+
current_variables = {
3718+
k: self._variables[k] for k in current_coord_names
3719+
}
3720+
else:
3721+
current_variables = {}
3722+
idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand(
3723+
dim,
3724+
current_variables,
3725+
{k: self._variables[k] for k in var_names},
3726+
)
3727+
for n in idx.index.names:
3728+
replace_dims[n] = dim
3729+
3730+
new_indexes.update({k: idx for k in idx_vars})
3731+
new_variables.update(idx_vars)
3732+
3733+
indexes_: Dict[Any, Index] = {
3734+
k: v for k, v in self.xindexes.items() if k not in maybe_drop_indexes
3735+
}
3736+
indexes_.update(new_indexes)
3737+
3738+
variables = {
3739+
k: v for k, v in self._variables.items() if k not in drop_variables
3740+
}
3741+
variables.update(new_variables)
3742+
3743+
# update dimensions if necessary, GH: 3512
3744+
for k, v in variables.items():
3745+
if any(d in replace_dims for d in v.dims):
3746+
new_dims = [replace_dims.get(d, d) for d in v.dims]
3747+
variables[k] = v._replace(dims=new_dims)
3748+
3749+
coord_names = set(new_variables) | self._coord_names
3750+
3751+
return self._replace_with_new_dims(
3752+
variables, coord_names=coord_names, indexes=indexes_
37573753
)
3758-
return self._replace_vars_and_dims(variables, coord_names=coord_names)
37593754

37603755
def reset_index(
37613756
self,

xarray/core/indexes.py

Lines changed: 105 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -312,32 +312,56 @@ def __getitem__(self, indexer: Any):
312312
return self._replace(self.index[indexer])
313313

314314

315-
def _create_variables_from_multiindex(index, dim, level_meta=None):
316-
from .variable import IndexVariable
315+
def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable:
316+
"""Check that all multi-index variable candidates share the same (single) dimension
317+
and return the name of that dimension.
318+
319+
"""
320+
if any([var.ndim != 1 for var in variables.values()]):
321+
raise ValueError("PandasMultiIndex only accepts 1-dimensional variables")
317322

318-
if level_meta is None:
319-
level_meta = {}
323+
dims = set([var.dims for var in variables.values()])
320324

321-
variables = {}
325+
if len(dims) > 1:
326+
raise ValueError(
327+
"unmatched dimensions for variables "
328+
+ ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()])
329+
)
322330

323-
dim_coord_adapter = PandasMultiIndexingAdapter(index)
324-
variables[dim] = IndexVariable(dim, dim_coord_adapter, fastpath=True)
331+
return next(iter(dims))[0]
325332

326-
for level in index.names:
327-
meta = level_meta.get(level, {})
333+
334+
def _create_variables_from_multiindex(index, dim, var_meta=None):
335+
from .variable import IndexVariable
336+
337+
if var_meta is None:
338+
var_meta = {}
339+
340+
def create_variable(name):
341+
if name == dim:
342+
level = None
343+
else:
344+
level = name
345+
meta = var_meta.get(name, {})
328346
data = PandasMultiIndexingAdapter(index, dtype=meta.get("dtype"), level=level)
329-
variables[level] = IndexVariable(
347+
return IndexVariable(
330348
dim,
331349
data,
332350
attrs=meta.get("attrs"),
333351
encoding=meta.get("encoding"),
334352
fastpath=True,
335353
)
336354

355+
variables = {}
356+
variables[dim] = create_variable(dim)
357+
for level in index.names:
358+
variables[level] = create_variable(level)
359+
337360
return variables
338361

339362

340363
class PandasMultiIndex(PandasIndex):
364+
"""Wrap a pandas.MultiIndex as an xarray compatible index."""
341365

342366
level_coords_dtype: Dict[str, Any]
343367

@@ -358,51 +382,101 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiInde
358382
return type(self)(index, dim, level_coords_dtype)
359383

360384
@classmethod
361-
def from_variables(cls, variables: Mapping[Any, "Variable"]):
362-
if any([var.ndim != 1 for var in variables.values()]):
363-
raise ValueError("PandasMultiIndex only accepts 1-dimensional variables")
364-
365-
dims = set([var.dims for var in variables.values()])
366-
if len(dims) != 1:
367-
raise ValueError(
368-
"unmatched dimensions for variables "
369-
+ ",".join([str(k) for k in variables])
370-
)
385+
def from_variables(
386+
cls, variables: Mapping[Any, "Variable"]
387+
) -> Tuple["PandasMultiIndex", IndexVars]:
388+
dim = _check_dim_compat(variables)
371389

372-
dim = next(iter(dims))[0]
373390
index = pd.MultiIndex.from_arrays(
374391
[var.values for var in variables.values()], names=variables.keys()
375392
)
376393
level_coords_dtype = {name: var.dtype for name, var in variables.items()}
377394
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
378395

379-
level_meta = {
396+
var_meta = {
380397
name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
381398
for name, var in variables.items()
382399
}
383-
index_vars = _create_variables_from_multiindex(
384-
index, dim, level_meta=level_meta
385-
)
400+
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
401+
402+
return obj, index_vars
403+
404+
@classmethod
405+
def from_variables_maybe_expand(
406+
cls,
407+
dim: Hashable,
408+
current_variables: Mapping[Any, "Variable"],
409+
variables: Mapping[Any, "Variable"],
410+
) -> Tuple["PandasMultiIndex", IndexVars]:
411+
"""Create a new multi-index maybe by expanding an existing one with
412+
new variables as index levels.
413+
414+
the index might be created along a new dimension.
415+
"""
416+
names: List[Hashable] = []
417+
codes: List[List[int]] = []
418+
levels: List[List[int]] = []
419+
var_meta: Dict[str, Dict] = {}
420+
level_coords_dtype: Dict[Hashable, Any] = {}
421+
422+
_check_dim_compat({**current_variables, **variables})
423+
424+
def add_level_var(name, var):
425+
var_meta[name] = {
426+
"dtype": var.dtype,
427+
"attrs": var.attrs,
428+
"encoding": var.encoding,
429+
}
430+
level_coords_dtype[name] = var.dtype
431+
432+
if len(current_variables) > 1:
433+
current_index: pd.MultiIndex = next(
434+
iter(current_variables.values())
435+
)._data.array
436+
names.extend(current_index.names)
437+
codes.extend(current_index.codes)
438+
levels.extend(current_index.levels)
439+
for name in current_index.names:
440+
add_level_var(name, current_variables[name])
441+
442+
elif len(current_variables) == 1:
443+
# one 1D variable (no multi-index): convert it to an index level
444+
var = next(iter(current_variables.values()))
445+
new_var_name = f"{dim}_level_0"
446+
names.append(new_var_name)
447+
cat = pd.Categorical(var.values, ordered=True)
448+
codes.append(cat.codes)
449+
levels.append(cat.categories)
450+
add_level_var(new_var_name, var)
451+
452+
for name, var in variables.items():
453+
names.append(name)
454+
cat = pd.Categorical(var.values, ordered=True)
455+
codes.append(cat.codes)
456+
levels.append(cat.categories)
457+
add_level_var(name, var)
458+
459+
index = pd.MultiIndex(levels, codes, names=names)
460+
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
461+
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
386462

387463
return obj, index_vars
388464

389465
@classmethod
390466
def from_pandas_index(
391467
cls, index: pd.MultiIndex, dim: Hashable
392468
) -> Tuple["PandasMultiIndex", IndexVars]:
393-
level_meta = {}
469+
var_meta = {}
394470
for i, idx in enumerate(index.levels):
395471
name = idx.name or f"{dim}_level_{i}"
396472
if name == dim:
397473
raise ValueError(
398474
f"conflicting multi-index level name {name!r} with dimension {dim!r}"
399475
)
400-
level_meta[name] = {"dtype": idx.dtype}
476+
var_meta[name] = {"dtype": idx.dtype}
401477

402-
index = index.rename(level_meta.keys())
403-
index_vars = _create_variables_from_multiindex(
404-
index, dim, level_meta=level_meta
405-
)
478+
index = index.rename(var_meta.keys())
479+
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
406480
return cls(index, dim), index_vars
407481

408482
def query(self, labels, method=None, tolerance=None) -> QueryResult:

0 commit comments

Comments
 (0)