6060 GroupKey = Any
6161 GroupIndex = Union [int , slice , list [int ]]
6262 T_GroupIndices = list [GroupIndex ]
63- T_FactorizeOut = tuple [
64- DataArray , T_GroupIndices , Union [pd .Index , "_DummyGroup" ], pd .Index
65- ]
6663
6764
6865def check_reduce_dims (reduce_dims , dimensions ):
@@ -98,7 +95,7 @@ def _maybe_squeeze_indices(
9895
9996def unique_value_groups (
10097 ar , sort : bool = True
101- ) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
98+ ) -> tuple [np .ndarray | pd .Index , np .ndarray ]:
10299 """Group an array by its unique values.
103100
104101 Parameters
@@ -119,11 +116,11 @@ def unique_value_groups(
119116 inverse , values = pd .factorize (ar , sort = sort )
120117 if isinstance (values , pd .MultiIndex ):
121118 values .names = ar .names
122- groups = _codes_to_groups (inverse , len (values ))
123- return values , groups , inverse
119+ return values , inverse
124120
125121
126- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
122+ def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> T_GroupIndices :
123+ assert inverse .ndim == 1
127124 groups : T_GroupIndices = [[] for _ in range (N )]
128125 for n , g in enumerate (inverse ):
129126 if g >= 0 :
@@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
356353 return False
357354
358355 @abstractmethod
359- def factorize (self , group ) -> T_FactorizeOut :
356+ def factorize (self , group ) -> EncodedGroups :
360357 """
361358 Takes the group, and creates intermediates necessary for GroupBy.
362359 These intermediates are
@@ -378,6 +375,27 @@ class Resampler(Grouper):
378375 pass
379376
380377
378+ @dataclass
379+ class EncodedGroups :
380+ """
381+ Dataclass for storing intermediate values for a GroupBy operation.
382+ Returned by the ``factorize`` method on Grouper objects.
383+
384+ Parameters
385+ ----------
386+ codes: integer codes for each group
387+ full_index: pandas Index for the group coordinate
388+ group_indices: optional, list of indices of array elements belonging
389+ to each group. Inferred from ``codes`` if not provided.
390+ unique_coord: optional, unique group values present in dataset. Inferred if not provided.
391+ """
392+
393+ codes : DataArray
394+ full_index : pd .Index
395+ group_indices : T_GroupIndices | None = field (default = None )
396+ unique_coord : IndexVariable | _DummyGroup | None = field (default = None )
397+
398+
381399@dataclass
382400class ResolvedGrouper (Generic [T_DataWithCoords ]):
383401 """
@@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
397415 group : T_Group
398416 obj : T_DataWithCoords
399417
400- # Defined by factorize:
418+ # returned by factorize:
401419 codes : DataArray = field (init = False )
420+ full_index : pd .Index = field (init = False )
402421 group_indices : T_GroupIndices = field (init = False )
403422 unique_coord : IndexVariable | _DummyGroup = field (init = False )
404- full_index : pd .Index = field (init = False )
405423
406424 # _ensure_1d:
407425 group1d : T_Group = field (init = False )
@@ -445,12 +463,26 @@ def dims(self):
445463 return self .group1d .dims
446464
447465 def factorize (self ) -> None :
448- (
449- self .codes ,
450- self .group_indices ,
451- self .unique_coord ,
452- self .full_index ,
453- ) = self .grouper .factorize (self .group1d )
466+ encoded = self .grouper .factorize (self .group1d )
467+
468+ self .codes = encoded .codes
469+ self .full_index = encoded .full_index
470+
471+ if encoded .group_indices is not None :
472+ self .group_indices = encoded .group_indices
473+ else :
474+ self .group_indices = [
475+ g
476+ for g in _codes_to_group_indices (self .codes .data , len (self .full_index ))
477+ if g
478+ ]
479+ if encoded .unique_coord is None :
480+ unique_values = self .full_index [np .unique (encoded .codes )]
481+ self .unique_coord = IndexVariable (
482+ self .group .name , unique_values , attrs = self .group .attrs
483+ )
484+ else :
485+ self .unique_coord = encoded .unique_coord
454486
455487
456488@dataclass
@@ -477,34 +509,33 @@ def can_squeeze(self) -> bool:
477509 is_dimension = self .group .dims == (self .group .name ,)
478510 return is_dimension and self .is_unique_and_monotonic
479511
480- def factorize (self , group1d ) -> T_FactorizeOut :
512+ def factorize (self , group1d ) -> EncodedGroups :
481513 self .group = group1d
482514
483515 if self .can_squeeze :
484516 return self ._factorize_dummy ()
485517 else :
486518 return self ._factorize_unique ()
487519
488- def _factorize_unique (self ) -> T_FactorizeOut :
520+ def _factorize_unique (self ) -> EncodedGroups :
489521 # look through group to find the unique values
490522 sort = not isinstance (self .group_as_index , pd .MultiIndex )
491- unique_values , group_indices , codes_ = unique_value_groups (
492- self .group_as_index , sort = sort
493- )
494- if len (group_indices ) == 0 :
523+ unique_values , codes_ = unique_value_groups (self .group_as_index , sort = sort )
524+ if (codes_ == - 1 ).all ():
495525 raise ValueError (
496526 "Failed to group data. Are you grouping by a variable that is all NaN?"
497527 )
498528 codes = self .group .copy (data = codes_ )
499- group_indices = group_indices
500529 unique_coord = IndexVariable (
501530 self .group .name , unique_values , attrs = self .group .attrs
502531 )
503532 full_index = unique_coord
504533
505- return codes , group_indices , unique_coord , full_index
534+ return EncodedGroups (
535+ codes = codes , full_index = full_index , unique_coord = unique_coord
536+ )
506537
507- def _factorize_dummy (self ) -> T_FactorizeOut :
538+ def _factorize_dummy (self ) -> EncodedGroups :
508539 size = self .group .size
509540 # no need to factorize
510541 # use slices to do views instead of fancy indexing
@@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
519550 full_index = IndexVariable (
520551 self .group .name , unique_coord .values , self .group .attrs
521552 )
522-
523- return codes , group_indices , unique_coord , full_index
553+ return EncodedGroups (
554+ codes = codes ,
555+ group_indices = group_indices ,
556+ full_index = full_index ,
557+ unique_coord = unique_coord ,
558+ )
524559
525560
526561@dataclass
@@ -536,7 +571,7 @@ def __post_init__(self) -> None:
536571 if duck_array_ops .isnull (self .bins ).all ():
537572 raise ValueError ("All bin edges are NaN." )
538573
539- def factorize (self , group ) -> T_FactorizeOut :
574+ def factorize (self , group ) -> EncodedGroups :
540575 from xarray .core .dataarray import DataArray
541576
542577 data = group .data
@@ -554,20 +589,14 @@ def factorize(self, group) -> T_FactorizeOut:
554589 full_index = binned .categories
555590 uniques = np .sort (pd .unique (binned_codes ))
556591 unique_values = full_index [uniques [uniques != - 1 ]]
557- group_indices = [
558- g for g in _codes_to_groups (binned_codes , len (full_index )) if g
559- ]
560-
561- if len (group_indices ) == 0 :
562- raise ValueError (
563- f"None of the data falls within bins with edges { self .bins !r} "
564- )
565592
566593 codes = DataArray (
567594 binned_codes , getattr (group , "coords" , None ), name = new_dim_name
568595 )
569596 unique_coord = IndexVariable (new_dim_name , pd .Index (unique_values ), group .attrs )
570- return codes , group_indices , unique_coord , full_index
597+ return EncodedGroups (
598+ codes = codes , full_index = full_index , unique_coord = unique_coord
599+ )
571600
572601
573602@dataclass
@@ -672,7 +701,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
672701 _apply_loffset (self .loffset , first_items )
673702 return first_items , codes
674703
675- def factorize (self , group ) -> T_FactorizeOut :
704+ def factorize (self , group ) -> EncodedGroups :
676705 self ._init_properties (group )
677706 full_index , first_items , codes_ = self ._get_index_and_items ()
678707 sbins = first_items .values .astype (np .int64 )
@@ -684,7 +713,12 @@ def factorize(self, group) -> T_FactorizeOut:
684713 unique_coord = IndexVariable (group .name , first_items .index , group .attrs )
685714 codes = group .copy (data = codes_ )
686715
687- return codes , group_indices , unique_coord , full_index
716+ return EncodedGroups (
717+ codes = codes ,
718+ group_indices = group_indices ,
719+ full_index = full_index ,
720+ unique_coord = unique_coord ,
721+ )
688722
689723
690724def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
0 commit comments