Skip to content

Commit 3cae0ad

Browse files
committed
Return dataclass from factorize
1 parent 589e897 commit 3cae0ad

File tree

1 file changed

+71
-40
lines changed

1 file changed

+71
-40
lines changed

xarray/core/groupby.py

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@
5454
GroupKey = Any
5555
GroupIndex = Union[int, slice, list[int]]
5656
T_GroupIndices = list[GroupIndex]
57-
T_FactorizeOut = tuple[
58-
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index, DataArray
59-
]
6057

6158

6259
def check_reduce_dims(reduce_dims, dimensions):
@@ -92,7 +89,7 @@ def _maybe_squeeze_indices(
9289

9390
def unique_value_groups(
9491
ar, sort: bool = True
95-
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
92+
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
9693
"""Group an array by its unique values.
9794
9895
Parameters
@@ -113,11 +110,11 @@ def unique_value_groups(
113110
inverse, values = pd.factorize(ar, sort=sort)
114111
if isinstance(values, pd.MultiIndex):
115112
values.names = ar.names
116-
groups = _codes_to_groups(inverse, len(values))
117-
return values, groups, inverse
113+
return values, inverse
118114

119115

120-
def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
116+
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
117+
assert inverse.ndim == 1
121118
groups: T_GroupIndices = [[] for _ in range(N)]
122119
for n, g in enumerate(inverse):
123120
if g >= 0:
@@ -341,16 +338,35 @@ def _apply_loffset(
341338

342339

343340
@dataclass
344-
class ResolvedGrouper:
341+
class EncodedGroups:
342+
"""
343+
Parameters
344+
----------
345+
codes:
346+
full_index:
347+
group_indices: optional,
348+
Inferred if not provided.
349+
unique_coord:
350+
Inferred if not provided
351+
"""
352+
353+
codes: DataArray
354+
full_index: pd.Index
355+
group_indices: T_GroupIndices | None = field(default=None)
356+
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
357+
358+
359+
@dataclass
360+
class ResolvedGrouper(Generic[T_Xarray]):
345361
grouper: Grouper
346362
group: T_Group
347363
obj: T_Xarray
348364

349-
# Defined by factorize:
365+
# returned by factorize:
350366
codes: DataArray = field(init=False)
367+
full_index: pd.Index = field(init=False)
351368
group_indices: T_GroupIndices = field(init=False)
352369
unique_coord: IndexVariable | _DummyGroup = field(init=False)
353-
full_index: pd.Index = field(init=False)
354370

355371
# _ensure_1d:
356372
group1d: T_Group = field(init=False)
@@ -394,20 +410,29 @@ def dims(self):
394410
return self.group1d.dims
395411

396412
def factorize(self) -> None:
397-
# This design makes it clear to mypy that
398-
# codes, group_indices, unique_coord, and full_index
399-
# are set by the factorize method on the derived class.
400-
(
401-
self.codes,
402-
self.group_indices,
403-
self.unique_coord,
404-
self.full_index,
405-
) = self.grouper.factorize(self.group1d)
413+
encoded = self.grouper.factorize(self.group1d)
414+
415+
self.codes = encoded.codes
416+
self.full_index = encoded.full_index
417+
418+
if encoded.group_indices is not None:
419+
self.group_indices = encoded.group_indices
420+
else:
421+
self.group_indices = [
422+
g
423+
for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
424+
if g
425+
]
426+
if encoded.unique_coord is None:
427+
# TODO
428+
raise NotImplementedError
429+
else:
430+
self.unique_coord = encoded.unique_coord
406431

407432

408433
class Grouper(ABC):
409434
@abstractmethod
410-
def factorize(self, group) -> T_FactorizeOut:
435+
def factorize(self, group: T_Group) -> EncodedGroups:
411436
pass
412437

413438

@@ -437,34 +462,33 @@ def can_squeeze(self) -> bool:
437462
is_dimension = self.group.dims == (self.group.name,)
438463
return is_dimension and self.is_unique_and_monotonic
439464

440-
def factorize(self, group1d) -> T_FactorizeOut:
465+
def factorize(self, group1d) -> EncodedGroups:
441466
self.group = group1d
442467

443468
if self.can_squeeze:
444469
return self._factorize_dummy()
445470
else:
446471
return self._factorize_unique()
447472

448-
def _factorize_unique(self) -> T_FactorizeOut:
473+
def _factorize_unique(self) -> EncodedGroups:
449474
# look through group to find the unique values
450475
sort = not isinstance(self.group_as_index, pd.MultiIndex)
451-
unique_values, group_indices, codes_ = unique_value_groups(
452-
self.group_as_index, sort=sort
453-
)
454-
if len(group_indices) == 0:
476+
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
477+
if (codes_ == -1).all():
455478
raise ValueError(
456479
"Failed to group data. Are you grouping by a variable that is all NaN?"
457480
)
458481
codes = self.group.copy(data=codes_)
459-
group_indices = group_indices
460482
unique_coord = IndexVariable(
461483
self.group.name, unique_values, attrs=self.group.attrs
462484
)
463485
full_index = unique_coord
464486

465-
return codes, group_indices, unique_coord, full_index
487+
return EncodedGroups(
488+
codes=codes, full_index=full_index, unique_coord=unique_coord
489+
)
466490

467-
def _factorize_dummy(self) -> T_FactorizeOut:
491+
def _factorize_dummy(self) -> EncodedGroups:
468492
size = self.group.size
469493
# no need to factorize
470494
# use slices to do views instead of fancy indexing
@@ -479,8 +503,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
479503
full_index = IndexVariable(
480504
self.group.name, unique_coord.values, self.group.attrs
481505
)
482-
483-
return codes, group_indices, unique_coord, full_index
506+
return EncodedGroups(
507+
codes=codes,
508+
group_indices=group_indices,
509+
full_index=full_index,
510+
unique_coord=unique_coord,
511+
)
484512

485513

486514
@dataclass
@@ -494,7 +522,7 @@ def __post_init__(self) -> None:
494522
if duck_array_ops.isnull(self.bins).all():
495523
raise ValueError("All bin edges are NaN.")
496524

497-
def factorize(self, group) -> T_FactorizeOut:
525+
def factorize(self, group) -> EncodedGroups:
498526
from xarray.core.dataarray import DataArray
499527

500528
data = group.data
@@ -512,11 +540,7 @@ def factorize(self, group) -> T_FactorizeOut:
512540
full_index = binned.categories
513541
uniques = np.sort(pd.unique(binned_codes))
514542
unique_values = full_index[uniques[uniques != -1]]
515-
group_indices = [
516-
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
517-
]
518-
519-
if len(group_indices) == 0:
543+
if (binned_codes == -1).all():
520544
raise ValueError(
521545
f"None of the data falls within bins with edges {self.bins!r}"
522546
)
@@ -525,7 +549,9 @@ def factorize(self, group) -> T_FactorizeOut:
525549
binned_codes, getattr(group, "coords", None), name=new_dim_name
526550
)
527551
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
528-
return codes, group_indices, unique_coord, full_index
552+
return EncodedGroups(
553+
codes=codes, full_index=full_index, unique_coord=unique_coord
554+
)
529555

530556

531557
@dataclass
@@ -628,7 +654,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
628654
_apply_loffset(self.loffset, first_items)
629655
return first_items, codes
630656

631-
def factorize(self, group) -> T_FactorizeOut:
657+
def factorize(self, group) -> EncodedGroups:
632658
self._init_properties(group)
633659
full_index, first_items, codes_ = self._get_index_and_items()
634660
sbins = first_items.values.astype(np.int64)
@@ -640,7 +666,12 @@ def factorize(self, group) -> T_FactorizeOut:
640666
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
641667
codes = group.copy(data=codes_)
642668

643-
return codes, group_indices, unique_coord, full_index
669+
return EncodedGroups(
670+
codes=codes,
671+
group_indices=group_indices,
672+
full_index=full_index,
673+
unique_coord=unique_coord,
674+
)
644675

645676

646677
def _validate_groupby_squeeze(squeeze: bool | None) -> None:

0 commit comments

Comments
 (0)