60
60
GroupKey = Any
61
61
GroupIndex = Union [int , slice , list [int ]]
62
62
T_GroupIndices = list [GroupIndex ]
63
- T_FactorizeOut = tuple [
64
- DataArray , T_GroupIndices , Union [pd .Index , "_DummyGroup" ], pd .Index
65
- ]
66
63
67
64
68
65
def check_reduce_dims (reduce_dims , dimensions ):
@@ -98,7 +95,7 @@ def _maybe_squeeze_indices(
98
95
99
96
def unique_value_groups (
100
97
ar , sort : bool = True
101
- ) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
98
+ ) -> tuple [np .ndarray | pd .Index , np .ndarray ]:
102
99
"""Group an array by its unique values.
103
100
104
101
Parameters
@@ -119,11 +116,11 @@ def unique_value_groups(
119
116
inverse , values = pd .factorize (ar , sort = sort )
120
117
if isinstance (values , pd .MultiIndex ):
121
118
values .names = ar .names
122
- groups = _codes_to_groups (inverse , len (values ))
123
- return values , groups , inverse
119
+ return values , inverse
124
120
125
121
126
- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
122
+ def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> T_GroupIndices :
123
+ assert inverse .ndim == 1
127
124
groups : T_GroupIndices = [[] for _ in range (N )]
128
125
for n , g in enumerate (inverse ):
129
126
if g >= 0 :
@@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
356
353
return False
357
354
358
355
@abstractmethod
359
- def factorize (self , group ) -> T_FactorizeOut :
356
+ def factorize (self , group ) -> EncodedGroups :
360
357
"""
361
358
Takes the group, and creates intermediates necessary for GroupBy.
362
359
These intermediates are
@@ -378,6 +375,27 @@ class Resampler(Grouper):
378
375
pass
379
376
380
377
378
+ @dataclass
379
+ class EncodedGroups :
380
+ """
381
+ Dataclass for storing intermediate values for GroupBy operation.
382
+ Returned by factorize method on Grouper objects.
383
+
384
+ Parameters
385
+ ----------
386
+ codes: integer codes for each group
387
+ full_index: pandas Index for the group coordinate
388
+ group_indices: optional, List of indices of array elements belonging
389
+ to each group. Inferred if not provided.
390
+ unique_coord: Unique group values present in dataset. Inferred if not provided
391
+ """
392
+
393
+ codes : DataArray
394
+ full_index : pd .Index
395
+ group_indices : T_GroupIndices | None = field (default = None )
396
+ unique_coord : IndexVariable | _DummyGroup | None = field (default = None )
397
+
398
+
381
399
@dataclass
382
400
class ResolvedGrouper (Generic [T_DataWithCoords ]):
383
401
"""
@@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
397
415
group : T_Group
398
416
obj : T_DataWithCoords
399
417
400
- # Defined by factorize:
418
+ # returned by factorize:
401
419
codes : DataArray = field (init = False )
420
+ full_index : pd .Index = field (init = False )
402
421
group_indices : T_GroupIndices = field (init = False )
403
422
unique_coord : IndexVariable | _DummyGroup = field (init = False )
404
- full_index : pd .Index = field (init = False )
405
423
406
424
# _ensure_1d:
407
425
group1d : T_Group = field (init = False )
@@ -445,12 +463,26 @@ def dims(self):
445
463
return self .group1d .dims
446
464
447
465
def factorize (self ) -> None :
448
- (
449
- self .codes ,
450
- self .group_indices ,
451
- self .unique_coord ,
452
- self .full_index ,
453
- ) = self .grouper .factorize (self .group1d )
466
+ encoded = self .grouper .factorize (self .group1d )
467
+
468
+ self .codes = encoded .codes
469
+ self .full_index = encoded .full_index
470
+
471
+ if encoded .group_indices is not None :
472
+ self .group_indices = encoded .group_indices
473
+ else :
474
+ self .group_indices = [
475
+ g
476
+ for g in _codes_to_group_indices (self .codes .data , len (self .full_index ))
477
+ if g
478
+ ]
479
+ if encoded .unique_coord is None :
480
+ unique_values = self .full_index [np .unique (encoded .codes )]
481
+ self .unique_coord = IndexVariable (
482
+ self .group .name , unique_values , attrs = self .group .attrs
483
+ )
484
+ else :
485
+ self .unique_coord = encoded .unique_coord
454
486
455
487
456
488
@dataclass
@@ -477,34 +509,33 @@ def can_squeeze(self) -> bool:
477
509
is_dimension = self .group .dims == (self .group .name ,)
478
510
return is_dimension and self .is_unique_and_monotonic
479
511
480
- def factorize (self , group1d ) -> T_FactorizeOut :
512
+ def factorize (self , group1d ) -> EncodedGroups :
481
513
self .group = group1d
482
514
483
515
if self .can_squeeze :
484
516
return self ._factorize_dummy ()
485
517
else :
486
518
return self ._factorize_unique ()
487
519
488
- def _factorize_unique (self ) -> T_FactorizeOut :
520
+ def _factorize_unique (self ) -> EncodedGroups :
489
521
# look through group to find the unique values
490
522
sort = not isinstance (self .group_as_index , pd .MultiIndex )
491
- unique_values , group_indices , codes_ = unique_value_groups (
492
- self .group_as_index , sort = sort
493
- )
494
- if len (group_indices ) == 0 :
523
+ unique_values , codes_ = unique_value_groups (self .group_as_index , sort = sort )
524
+ if (codes_ == - 1 ).all ():
495
525
raise ValueError (
496
526
"Failed to group data. Are you grouping by a variable that is all NaN?"
497
527
)
498
528
codes = self .group .copy (data = codes_ )
499
- group_indices = group_indices
500
529
unique_coord = IndexVariable (
501
530
self .group .name , unique_values , attrs = self .group .attrs
502
531
)
503
532
full_index = unique_coord
504
533
505
- return codes , group_indices , unique_coord , full_index
534
+ return EncodedGroups (
535
+ codes = codes , full_index = full_index , unique_coord = unique_coord
536
+ )
506
537
507
- def _factorize_dummy (self ) -> T_FactorizeOut :
538
+ def _factorize_dummy (self ) -> EncodedGroups :
508
539
size = self .group .size
509
540
# no need to factorize
510
541
# use slices to do views instead of fancy indexing
@@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
519
550
full_index = IndexVariable (
520
551
self .group .name , unique_coord .values , self .group .attrs
521
552
)
522
-
523
- return codes , group_indices , unique_coord , full_index
553
+ return EncodedGroups (
554
+ codes = codes ,
555
+ group_indices = group_indices ,
556
+ full_index = full_index ,
557
+ unique_coord = unique_coord ,
558
+ )
524
559
525
560
526
561
@dataclass
@@ -536,7 +571,7 @@ def __post_init__(self) -> None:
536
571
if duck_array_ops .isnull (self .bins ).all ():
537
572
raise ValueError ("All bin edges are NaN." )
538
573
539
- def factorize (self , group ) -> T_FactorizeOut :
574
+ def factorize (self , group ) -> EncodedGroups :
540
575
from xarray .core .dataarray import DataArray
541
576
542
577
data = group .data
@@ -554,20 +589,14 @@ def factorize(self, group) -> T_FactorizeOut:
554
589
full_index = binned .categories
555
590
uniques = np .sort (pd .unique (binned_codes ))
556
591
unique_values = full_index [uniques [uniques != - 1 ]]
557
- group_indices = [
558
- g for g in _codes_to_groups (binned_codes , len (full_index )) if g
559
- ]
560
-
561
- if len (group_indices ) == 0 :
562
- raise ValueError (
563
- f"None of the data falls within bins with edges { self .bins !r} "
564
- )
565
592
566
593
codes = DataArray (
567
594
binned_codes , getattr (group , "coords" , None ), name = new_dim_name
568
595
)
569
596
unique_coord = IndexVariable (new_dim_name , pd .Index (unique_values ), group .attrs )
570
- return codes , group_indices , unique_coord , full_index
597
+ return EncodedGroups (
598
+ codes = codes , full_index = full_index , unique_coord = unique_coord
599
+ )
571
600
572
601
573
602
@dataclass
@@ -672,7 +701,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
672
701
_apply_loffset (self .loffset , first_items )
673
702
return first_items , codes
674
703
675
- def factorize (self , group ) -> T_FactorizeOut :
704
+ def factorize (self , group ) -> EncodedGroups :
676
705
self ._init_properties (group )
677
706
full_index , first_items , codes_ = self ._get_index_and_items ()
678
707
sbins = first_items .values .astype (np .int64 )
@@ -684,7 +713,12 @@ def factorize(self, group) -> T_FactorizeOut:
684
713
unique_coord = IndexVariable (group .name , first_items .index , group .attrs )
685
714
codes = group .copy (data = codes_ )
686
715
687
- return codes , group_indices , unique_coord , full_index
716
+ return EncodedGroups (
717
+ codes = codes ,
718
+ group_indices = group_indices ,
719
+ full_index = full_index ,
720
+ unique_coord = unique_coord ,
721
+ )
688
722
689
723
690
724
def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
0 commit comments