Skip to content

Commit 6072586

Browse files
committed
Return dataclass from factorize
1 parent 8468be0 commit 6072586

File tree

1 file changed

+72
-36
lines changed

1 file changed

+72
-36
lines changed

xarray/core/groupby.py

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@
6060
GroupKey = Any
6161
GroupIndex = Union[int, slice, list[int]]
6262
T_GroupIndices = list[GroupIndex]
63-
T_FactorizeOut = tuple[
64-
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index
65-
]
6663

6764

6865
def check_reduce_dims(reduce_dims, dimensions):
@@ -98,7 +95,7 @@ def _maybe_squeeze_indices(
9895

9996
def unique_value_groups(
10097
ar, sort: bool = True
101-
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
98+
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
10299
"""Group an array by its unique values.
103100
104101
Parameters
@@ -119,11 +116,11 @@ def unique_value_groups(
119116
inverse, values = pd.factorize(ar, sort=sort)
120117
if isinstance(values, pd.MultiIndex):
121118
values.names = ar.names
122-
groups = _codes_to_groups(inverse, len(values))
123-
return values, groups, inverse
119+
return values, inverse
124120

125121

126-
def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
122+
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
123+
assert inverse.ndim == 1
127124
groups: T_GroupIndices = [[] for _ in range(N)]
128125
for n, g in enumerate(inverse):
129126
if g >= 0:
@@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
356353
return False
357354

358355
@abstractmethod
359-
def factorize(self, group) -> T_FactorizeOut:
356+
def factorize(self, group) -> EncodedGroups:
360357
"""
361358
Takes the group, and creates intermediates necessary for GroupBy.
362359
These intermediates are
@@ -378,6 +375,27 @@ class Resampler(Grouper):
378375
pass
379376

380377

378+
@dataclass
379+
class EncodedGroups:
380+
"""
381+
Dataclass for storing intermediate values for GroupBy operation.
382+
Returned by factorize method on Grouper objects.
383+
384+
Parameters
385+
----------
386+
codes: integer codes for each group
387+
full_index: pandas Index for the group coordinate
388+
group_indices: optional, List of indices of array elements belonging
389+
to each group. Inferred if not provided.
390+
unique_coord: Unique group values present in dataset. Inferred if not provided
391+
"""
392+
393+
codes: DataArray
394+
full_index: pd.Index
395+
group_indices: T_GroupIndices | None = field(default=None)
396+
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
397+
398+
381399
@dataclass
382400
class ResolvedGrouper(Generic[T_DataWithCoords]):
383401
"""
@@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
397415
group: T_Group
398416
obj: T_DataWithCoords
399417

400-
# Defined by factorize:
418+
# returned by factorize:
401419
codes: DataArray = field(init=False)
420+
full_index: pd.Index = field(init=False)
402421
group_indices: T_GroupIndices = field(init=False)
403422
unique_coord: IndexVariable | _DummyGroup = field(init=False)
404-
full_index: pd.Index = field(init=False)
405423

406424
# _ensure_1d:
407425
group1d: T_Group = field(init=False)
@@ -445,12 +463,24 @@ def dims(self):
445463
return self.group1d.dims
446464

447465
def factorize(self) -> None:
448-
(
449-
self.codes,
450-
self.group_indices,
451-
self.unique_coord,
452-
self.full_index,
453-
) = self.grouper.factorize(self.group1d)
466+
encoded = self.grouper.factorize(self.group1d)
467+
468+
self.codes = encoded.codes
469+
self.full_index = encoded.full_index
470+
471+
if encoded.group_indices is not None:
472+
self.group_indices = encoded.group_indices
473+
else:
474+
self.group_indices = [
475+
g
476+
for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
477+
if g
478+
]
479+
if encoded.unique_coord is None:
480+
# TODO
481+
raise NotImplementedError
482+
else:
483+
self.unique_coord = encoded.unique_coord
454484

455485

456486
@dataclass
@@ -477,34 +507,33 @@ def can_squeeze(self) -> bool:
477507
is_dimension = self.group.dims == (self.group.name,)
478508
return is_dimension and self.is_unique_and_monotonic
479509

480-
def factorize(self, group1d) -> T_FactorizeOut:
510+
def factorize(self, group1d) -> EncodedGroups:
481511
self.group = group1d
482512

483513
if self.can_squeeze:
484514
return self._factorize_dummy()
485515
else:
486516
return self._factorize_unique()
487517

488-
def _factorize_unique(self) -> T_FactorizeOut:
518+
def _factorize_unique(self) -> EncodedGroups:
489519
# look through group to find the unique values
490520
sort = not isinstance(self.group_as_index, pd.MultiIndex)
491-
unique_values, group_indices, codes_ = unique_value_groups(
492-
self.group_as_index, sort=sort
493-
)
494-
if len(group_indices) == 0:
521+
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
522+
if (codes_ == -1).all():
495523
raise ValueError(
496524
"Failed to group data. Are you grouping by a variable that is all NaN?"
497525
)
498526
codes = self.group.copy(data=codes_)
499-
group_indices = group_indices
500527
unique_coord = IndexVariable(
501528
self.group.name, unique_values, attrs=self.group.attrs
502529
)
503530
full_index = unique_coord
504531

505-
return codes, group_indices, unique_coord, full_index
532+
return EncodedGroups(
533+
codes=codes, full_index=full_index, unique_coord=unique_coord
534+
)
506535

507-
def _factorize_dummy(self) -> T_FactorizeOut:
536+
def _factorize_dummy(self) -> EncodedGroups:
508537
size = self.group.size
509538
# no need to factorize
510539
# use slices to do views instead of fancy indexing
@@ -519,8 +548,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
519548
full_index = IndexVariable(
520549
self.group.name, unique_coord.values, self.group.attrs
521550
)
522-
523-
return codes, group_indices, unique_coord, full_index
551+
return EncodedGroups(
552+
codes=codes,
553+
group_indices=group_indices,
554+
full_index=full_index,
555+
unique_coord=unique_coord,
556+
)
524557

525558

526559
@dataclass
@@ -536,7 +569,7 @@ def __post_init__(self) -> None:
536569
if duck_array_ops.isnull(self.bins).all():
537570
raise ValueError("All bin edges are NaN.")
538571

539-
def factorize(self, group) -> T_FactorizeOut:
572+
def factorize(self, group) -> EncodedGroups:
540573
from xarray.core.dataarray import DataArray
541574

542575
data = group.data
@@ -554,11 +587,7 @@ def factorize(self, group) -> T_FactorizeOut:
554587
full_index = binned.categories
555588
uniques = np.sort(pd.unique(binned_codes))
556589
unique_values = full_index[uniques[uniques != -1]]
557-
group_indices = [
558-
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
559-
]
560-
561-
if len(group_indices) == 0:
590+
if (binned_codes == -1).all():
562591
raise ValueError(
563592
f"None of the data falls within bins with edges {self.bins!r}"
564593
)
@@ -567,7 +596,9 @@ def factorize(self, group) -> T_FactorizeOut:
567596
binned_codes, getattr(group, "coords", None), name=new_dim_name
568597
)
569598
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
570-
return codes, group_indices, unique_coord, full_index
599+
return EncodedGroups(
600+
codes=codes, full_index=full_index, unique_coord=unique_coord
601+
)
571602

572603

573604
@dataclass
@@ -672,7 +703,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
672703
_apply_loffset(self.loffset, first_items)
673704
return first_items, codes
674705

675-
def factorize(self, group) -> T_FactorizeOut:
706+
def factorize(self, group) -> EncodedGroups:
676707
self._init_properties(group)
677708
full_index, first_items, codes_ = self._get_index_and_items()
678709
sbins = first_items.values.astype(np.int64)
@@ -684,7 +715,12 @@ def factorize(self, group) -> T_FactorizeOut:
684715
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
685716
codes = group.copy(data=codes_)
686717

687-
return codes, group_indices, unique_coord, full_index
718+
return EncodedGroups(
719+
codes=codes,
720+
group_indices=group_indices,
721+
full_index=full_index,
722+
unique_coord=unique_coord,
723+
)
688724

689725

690726
def _validate_groupby_squeeze(squeeze: bool | None) -> None:

0 commit comments

Comments
 (0)