Skip to content

Commit 734c9c2

Browse files
committed
Return dataclass from factorize
1 parent 8468be0 commit 734c9c2

File tree

1 file changed

+74
-36
lines changed

1 file changed

+74
-36
lines changed

xarray/core/groupby.py

Lines changed: 74 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -60,9 +60,6 @@
6060
GroupKey = Any
6161
GroupIndex = Union[int, slice, list[int]]
6262
T_GroupIndices = list[GroupIndex]
63-
T_FactorizeOut = tuple[
64-
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index
65-
]
6663

6764

6865
def check_reduce_dims(reduce_dims, dimensions):
@@ -98,7 +95,7 @@ def _maybe_squeeze_indices(
9895

9996
def unique_value_groups(
10097
ar, sort: bool = True
101-
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
98+
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
10299
"""Group an array by its unique values.
103100
104101
Parameters
@@ -119,11 +116,11 @@ def unique_value_groups(
119116
inverse, values = pd.factorize(ar, sort=sort)
120117
if isinstance(values, pd.MultiIndex):
121118
values.names = ar.names
122-
groups = _codes_to_groups(inverse, len(values))
123-
return values, groups, inverse
119+
return values, inverse
124120

125121

126-
def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
122+
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
123+
assert inverse.ndim == 1
127124
groups: T_GroupIndices = [[] for _ in range(N)]
128125
for n, g in enumerate(inverse):
129126
if g >= 0:
@@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
356353
return False
357354

358355
@abstractmethod
359-
def factorize(self, group) -> T_FactorizeOut:
356+
def factorize(self, group) -> EncodedGroups:
360357
"""
361358
Takes the group, and creates intermediates necessary for GroupBy.
362359
These intermediates are
@@ -378,6 +375,27 @@ class Resampler(Grouper):
378375
pass
379376

380377

378+
@dataclass
379+
class EncodedGroups:
380+
"""
381+
Dataclass for storing intermediate values for GroupBy operation.
382+
Returned by factorize method on Grouper objects.
383+
384+
Parameters
385+
----------
386+
codes: integer codes for each group
387+
full_index: pandas Index for the group coordinate
388+
group_indices: optional, List of indices of array elements belonging
389+
to each group. Inferred if not provided.
390+
unique_coord: Unique group values present in dataset. Inferred if not provided
391+
"""
392+
393+
codes: DataArray
394+
full_index: pd.Index
395+
group_indices: T_GroupIndices | None = field(default=None)
396+
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
397+
398+
381399
@dataclass
382400
class ResolvedGrouper(Generic[T_DataWithCoords]):
383401
"""
@@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
397415
group: T_Group
398416
obj: T_DataWithCoords
399417

400-
# Defined by factorize:
418+
# returned by factorize:
401419
codes: DataArray = field(init=False)
420+
full_index: pd.Index = field(init=False)
402421
group_indices: T_GroupIndices = field(init=False)
403422
unique_coord: IndexVariable | _DummyGroup = field(init=False)
404-
full_index: pd.Index = field(init=False)
405423

406424
# _ensure_1d:
407425
group1d: T_Group = field(init=False)
@@ -445,12 +463,26 @@ def dims(self):
445463
return self.group1d.dims
446464

447465
def factorize(self) -> None:
448-
(
449-
self.codes,
450-
self.group_indices,
451-
self.unique_coord,
452-
self.full_index,
453-
) = self.grouper.factorize(self.group1d)
466+
encoded = self.grouper.factorize(self.group1d)
467+
468+
self.codes = encoded.codes
469+
self.full_index = encoded.full_index
470+
471+
if encoded.group_indices is not None:
472+
self.group_indices = encoded.group_indices
473+
else:
474+
self.group_indices = [
475+
g
476+
for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
477+
if g
478+
]
479+
if encoded.unique_coord is None:
480+
unique_values = self.full_index[np.unique(encoded.codes)]
481+
self.unique_coord = IndexVariable(
482+
self.group.name, unique_values, attrs=self.group.attrs
483+
)
484+
else:
485+
self.unique_coord = encoded.unique_coord
454486

455487

456488
@dataclass
@@ -477,34 +509,33 @@ def can_squeeze(self) -> bool:
477509
is_dimension = self.group.dims == (self.group.name,)
478510
return is_dimension and self.is_unique_and_monotonic
479511

480-
def factorize(self, group1d) -> T_FactorizeOut:
512+
def factorize(self, group1d) -> EncodedGroups:
481513
self.group = group1d
482514

483515
if self.can_squeeze:
484516
return self._factorize_dummy()
485517
else:
486518
return self._factorize_unique()
487519

488-
def _factorize_unique(self) -> T_FactorizeOut:
520+
def _factorize_unique(self) -> EncodedGroups:
489521
# look through group to find the unique values
490522
sort = not isinstance(self.group_as_index, pd.MultiIndex)
491-
unique_values, group_indices, codes_ = unique_value_groups(
492-
self.group_as_index, sort=sort
493-
)
494-
if len(group_indices) == 0:
523+
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
524+
if (codes_ == -1).all():
495525
raise ValueError(
496526
"Failed to group data. Are you grouping by a variable that is all NaN?"
497527
)
498528
codes = self.group.copy(data=codes_)
499-
group_indices = group_indices
500529
unique_coord = IndexVariable(
501530
self.group.name, unique_values, attrs=self.group.attrs
502531
)
503532
full_index = unique_coord
504533

505-
return codes, group_indices, unique_coord, full_index
534+
return EncodedGroups(
535+
codes=codes, full_index=full_index, unique_coord=unique_coord
536+
)
506537

507-
def _factorize_dummy(self) -> T_FactorizeOut:
538+
def _factorize_dummy(self) -> EncodedGroups:
508539
size = self.group.size
509540
# no need to factorize
510541
# use slices to do views instead of fancy indexing
@@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
519550
full_index = IndexVariable(
520551
self.group.name, unique_coord.values, self.group.attrs
521552
)
522-
523-
return codes, group_indices, unique_coord, full_index
553+
return EncodedGroups(
554+
codes=codes,
555+
group_indices=group_indices,
556+
full_index=full_index,
557+
unique_coord=unique_coord,
558+
)
524559

525560

526561
@dataclass
@@ -536,7 +571,7 @@ def __post_init__(self) -> None:
536571
if duck_array_ops.isnull(self.bins).all():
537572
raise ValueError("All bin edges are NaN.")
538573

539-
def factorize(self, group) -> T_FactorizeOut:
574+
def factorize(self, group) -> EncodedGroups:
540575
from xarray.core.dataarray import DataArray
541576

542577
data = group.data
@@ -554,11 +589,7 @@ def factorize(self, group) -> T_FactorizeOut:
554589
full_index = binned.categories
555590
uniques = np.sort(pd.unique(binned_codes))
556591
unique_values = full_index[uniques[uniques != -1]]
557-
group_indices = [
558-
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
559-
]
560-
561-
if len(group_indices) == 0:
592+
if (binned_codes == -1).all():
562593
raise ValueError(
563594
f"None of the data falls within bins with edges {self.bins!r}"
564595
)
@@ -567,7 +598,9 @@ def factorize(self, group) -> T_FactorizeOut:
567598
binned_codes, getattr(group, "coords", None), name=new_dim_name
568599
)
569600
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
570-
return codes, group_indices, unique_coord, full_index
601+
return EncodedGroups(
602+
codes=codes, full_index=full_index, unique_coord=unique_coord
603+
)
571604

572605

573606
@dataclass
@@ -672,7 +705,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
672705
_apply_loffset(self.loffset, first_items)
673706
return first_items, codes
674707

675-
def factorize(self, group) -> T_FactorizeOut:
708+
def factorize(self, group) -> EncodedGroups:
676709
self._init_properties(group)
677710
full_index, first_items, codes_ = self._get_index_and_items()
678711
sbins = first_items.values.astype(np.int64)
@@ -684,7 +717,12 @@ def factorize(self, group) -> T_FactorizeOut:
684717
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
685718
codes = group.copy(data=codes_)
686719

687-
return codes, group_indices, unique_coord, full_index
720+
return EncodedGroups(
721+
codes=codes,
722+
group_indices=group_indices,
723+
full_index=full_index,
724+
unique_coord=unique_coord,
725+
)
688726

689727

690728
def _validate_groupby_squeeze(squeeze: bool | None) -> None:

0 commit comments

Comments
 (0)