60
60
GroupKey = Any
# Index selecting the array elements that belong to one group: a single
# position, a slice, or an explicit list of positions.
GroupIndex = Union[int, slice, list[int]]
# One GroupIndex per group.
T_GroupIndices = list[GroupIndex]
66
63
67
64
68
65
def check_reduce_dims (reduce_dims , dimensions ):
@@ -98,7 +95,7 @@ def _maybe_squeeze_indices(
98
95
99
96
def unique_value_groups (
100
97
ar , sort : bool = True
101
- ) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
98
+ ) -> tuple [np .ndarray | pd .Index , np .ndarray ]:
102
99
"""Group an array by its unique values.
103
100
104
101
Parameters
@@ -119,11 +116,11 @@ def unique_value_groups(
119
116
inverse , values = pd .factorize (ar , sort = sort )
120
117
if isinstance (values , pd .MultiIndex ):
121
118
values .names = ar .names
122
- groups = _codes_to_groups (inverse , len (values ))
123
- return values , groups , inverse
119
+ return values , inverse
124
120
125
121
126
- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
122
+ def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> T_GroupIndices :
123
+ assert inverse .ndim == 1
127
124
groups : T_GroupIndices = [[] for _ in range (N )]
128
125
for n , g in enumerate (inverse ):
129
126
if g >= 0 :
@@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
356
353
return False
357
354
358
355
@abstractmethod
359
- def factorize (self , group ) -> T_FactorizeOut :
356
+ def factorize (self , group ) -> EncodedGroups :
360
357
"""
361
358
Takes the group, and creates intermediates necessary for GroupBy.
362
359
These intermediates are
@@ -378,6 +375,27 @@ class Resampler(Grouper):
378
375
pass
379
376
380
377
378
@dataclass
class EncodedGroups:
    """
    Dataclass for storing intermediate values for GroupBy operation.
    Returned by factorize method on Grouper objects.

    Parameters
    ----------
    codes : DataArray
        Integer codes for each group.
    full_index : pd.Index
        pandas Index for the group coordinate.
    group_indices : T_GroupIndices, optional
        List of indices of array elements belonging to each group.
        ``ResolvedGrouper.factorize`` infers it from ``codes`` when it is
        not provided.
    unique_coord : IndexVariable or _DummyGroup, optional
        Unique group values present in dataset. Inference is not implemented
        yet: ``ResolvedGrouper.factorize`` raises ``NotImplementedError``
        when it is not provided.
    """

    codes: DataArray
    full_index: pd.Index
    group_indices: T_GroupIndices | None = None
    unique_coord: IndexVariable | _DummyGroup | None = None
397
+
398
+
381
399
@dataclass
382
400
class ResolvedGrouper (Generic [T_DataWithCoords ]):
383
401
"""
@@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
397
415
group : T_Group
398
416
obj : T_DataWithCoords
399
417
400
- # Defined by factorize:
418
+ # returned by factorize:
401
419
codes : DataArray = field (init = False )
420
+ full_index : pd .Index = field (init = False )
402
421
group_indices : T_GroupIndices = field (init = False )
403
422
unique_coord : IndexVariable | _DummyGroup = field (init = False )
404
- full_index : pd .Index = field (init = False )
405
423
406
424
# _ensure_1d:
407
425
group1d : T_Group = field (init = False )
@@ -445,12 +463,24 @@ def dims(self):
445
463
return self .group1d .dims
446
464
447
465
def factorize(self) -> None:
    """
    Factorize ``self.group1d`` with ``self.grouper`` and store the results.

    Sets ``self.codes``, ``self.full_index``, ``self.group_indices`` and
    ``self.unique_coord`` from the ``EncodedGroups`` returned by the grouper.

    Raises
    ------
    NotImplementedError
        If the grouper did not provide ``unique_coord``.
    """
    encoded = self.grouper.factorize(self.group1d)

    self.codes = encoded.codes
    self.full_index = encoded.full_index

    if encoded.group_indices is not None:
        self.group_indices = encoded.group_indices
    else:
        # Build the per-group indices from the integer codes and drop
        # the entries that come back empty.
        self.group_indices = [
            g
            for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
            if g
        ]
    if encoded.unique_coord is None:
        # TODO: infer unique_coord when the grouper does not provide it.
        raise NotImplementedError(
            "Inferring unique_coord is not implemented; the grouper must "
            "return EncodedGroups with unique_coord set."
        )
    else:
        self.unique_coord = encoded.unique_coord
454
484
455
485
456
486
@dataclass
@@ -477,34 +507,33 @@ def can_squeeze(self) -> bool:
477
507
is_dimension = self .group .dims == (self .group .name ,)
478
508
return is_dimension and self .is_unique_and_monotonic
479
509
480
def factorize(self, group1d) -> EncodedGroups:
    """
    Encode ``group1d`` and return the resulting ``EncodedGroups``.

    Stores ``group1d`` on ``self.group`` and then delegates to
    ``_factorize_dummy`` when ``self.can_squeeze`` is true, otherwise to
    ``_factorize_unique``.
    """
    # Must be assigned before reading can_squeeze, which inspects self.group.
    self.group = group1d

    if self.can_squeeze:
        return self._factorize_dummy()
    else:
        return self._factorize_unique()
487
517
488
def _factorize_unique(self) -> EncodedGroups:
    """
    Encode ``self.group`` by its unique values.

    Returns
    -------
    EncodedGroups
        With ``codes``, ``full_index`` and ``unique_coord`` set.
        ``group_indices`` is left unset.

    Raises
    ------
    ValueError
        If every code is -1, i.e. no value could be assigned to a group.
    """
    # look through group to find the unique values
    # (the unique values are not sorted when grouping by a MultiIndex)
    sort = not isinstance(self.group_as_index, pd.MultiIndex)
    unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
    if (codes_ == -1).all():
        raise ValueError(
            "Failed to group data. Are you grouping by a variable that is all NaN?"
        )
    codes = self.group.copy(data=codes_)
    unique_coord = IndexVariable(
        self.group.name, unique_values, attrs=self.group.attrs
    )
    # The full index is the same object as the unique coordinate here.
    full_index = unique_coord

    return EncodedGroups(
        codes=codes, full_index=full_index, unique_coord=unique_coord
    )
506
535
507
- def _factorize_dummy (self ) -> T_FactorizeOut :
536
+ def _factorize_dummy (self ) -> EncodedGroups :
508
537
size = self .group .size
509
538
# no need to factorize
510
539
# use slices to do views instead of fancy indexing
@@ -519,8 +548,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
519
548
full_index = IndexVariable (
520
549
self .group .name , unique_coord .values , self .group .attrs
521
550
)
522
-
523
- return codes , group_indices , unique_coord , full_index
551
+ return EncodedGroups (
552
+ codes = codes ,
553
+ group_indices = group_indices ,
554
+ full_index = full_index ,
555
+ unique_coord = unique_coord ,
556
+ )
524
557
525
558
526
559
@dataclass
@@ -536,7 +569,7 @@ def __post_init__(self) -> None:
536
569
if duck_array_ops .isnull (self .bins ).all ():
537
570
raise ValueError ("All bin edges are NaN." )
538
571
539
- def factorize (self , group ) -> T_FactorizeOut :
572
+ def factorize (self , group ) -> EncodedGroups :
540
573
from xarray .core .dataarray import DataArray
541
574
542
575
data = group .data
@@ -554,11 +587,7 @@ def factorize(self, group) -> T_FactorizeOut:
554
587
full_index = binned .categories
555
588
uniques = np .sort (pd .unique (binned_codes ))
556
589
unique_values = full_index [uniques [uniques != - 1 ]]
557
- group_indices = [
558
- g for g in _codes_to_groups (binned_codes , len (full_index )) if g
559
- ]
560
-
561
- if len (group_indices ) == 0 :
590
+ if (binned_codes == - 1 ).all ():
562
591
raise ValueError (
563
592
f"None of the data falls within bins with edges { self .bins !r} "
564
593
)
@@ -567,7 +596,9 @@ def factorize(self, group) -> T_FactorizeOut:
567
596
binned_codes , getattr (group , "coords" , None ), name = new_dim_name
568
597
)
569
598
unique_coord = IndexVariable (new_dim_name , pd .Index (unique_values ), group .attrs )
570
- return codes , group_indices , unique_coord , full_index
599
+ return EncodedGroups (
600
+ codes = codes , full_index = full_index , unique_coord = unique_coord
601
+ )
571
602
572
603
573
604
@dataclass
@@ -672,7 +703,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
672
703
_apply_loffset (self .loffset , first_items )
673
704
return first_items , codes
674
705
675
- def factorize (self , group ) -> T_FactorizeOut :
706
+ def factorize (self , group ) -> EncodedGroups :
676
707
self ._init_properties (group )
677
708
full_index , first_items , codes_ = self ._get_index_and_items ()
678
709
sbins = first_items .values .astype (np .int64 )
@@ -684,7 +715,12 @@ def factorize(self, group) -> T_FactorizeOut:
684
715
unique_coord = IndexVariable (group .name , first_items .index , group .attrs )
685
716
codes = group .copy (data = codes_ )
686
717
687
- return codes , group_indices , unique_coord , full_index
718
+ return EncodedGroups (
719
+ codes = codes ,
720
+ group_indices = group_indices ,
721
+ full_index = full_index ,
722
+ unique_coord = unique_coord ,
723
+ )
688
724
689
725
690
726
def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
0 commit comments