54
54
# Key that identifies a single group; typed as Any, so no constraint is imposed here.
GroupKey = Any
55
55
# Indexer for the members of one group: a single int, a slice, or a list of ints.
GroupIndex = Union [int , slice , list [int ]]
56
56
# Collection of per-group indexers (presumably one entry per group — confirm with callers).
T_GroupIndices = list [GroupIndex ]
57
- T_FactorizeOut = tuple [
58
- DataArray , T_GroupIndices , Union [pd .Index , "_DummyGroup" ], pd .Index , DataArray
59
- ]
60
57
61
58
62
59
def check_reduce_dims (reduce_dims , dimensions ):
@@ -92,7 +89,7 @@ def _maybe_squeeze_indices(
92
89
93
90
def unique_value_groups (
94
91
ar , sort : bool = True
95
- ) -> tuple [np .ndarray | pd .Index , T_GroupIndices , np .ndarray ]:
92
+ ) -> tuple [np .ndarray | pd .Index , np .ndarray ]:
96
93
"""Group an array by its unique values.
97
94
98
95
Parameters
@@ -113,11 +110,11 @@ def unique_value_groups(
113
110
inverse , values = pd .factorize (ar , sort = sort )
114
111
if isinstance (values , pd .MultiIndex ):
115
112
values .names = ar .names
116
- groups = _codes_to_groups (inverse , len (values ))
117
- return values , groups , inverse
113
+ return values , inverse
118
114
119
115
120
- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
116
+ def _codes_to_group_indices (inverse : np .ndarray , N : int ) -> T_GroupIndices :
117
+ assert inverse .ndim == 1
121
118
groups : T_GroupIndices = [[] for _ in range (N )]
122
119
for n , g in enumerate (inverse ):
123
120
if g >= 0 :
@@ -341,16 +338,35 @@ def _apply_loffset(
341
338
342
339
343
340
@dataclass
class EncodedGroups:
    """
    Result of factorizing a group variable, as returned by ``Grouper.factorize``.

    Parameters
    ----------
    codes:
        Integer code assigned to every element of the group variable. The
        factorize implementations in this module treat a code of ``-1`` as
        "belongs to no group".
    full_index:
        Index of all candidate groups. ``ResolvedGrouper.factorize`` uses
        ``len(full_index)`` as the number of groups when it has to derive
        ``group_indices`` from ``codes``.
    group_indices: optional,
        One indexer per group. When omitted, ``ResolvedGrouper.factorize``
        derives it from ``codes`` and drops empty groups.
    unique_coord: optional,
        Coordinate describing the groups. See ``ResolvedGrouper.factorize``
        for how a missing value is handled.
    """

    codes: DataArray
    full_index: pd.Index
    group_indices: T_GroupIndices | None = None
    unique_coord: IndexVariable | _DummyGroup | None = None
357
+
358
+
359
+ @dataclass
360
+ class ResolvedGrouper (Generic [T_Xarray ]):
345
361
grouper : Grouper
346
362
group : T_Group
347
363
obj : T_Xarray
348
364
349
- # Defined by factorize:
365
+ # returned by factorize:
350
366
codes : DataArray = field (init = False )
367
+ full_index : pd .Index = field (init = False )
351
368
group_indices : T_GroupIndices = field (init = False )
352
369
unique_coord : IndexVariable | _DummyGroup = field (init = False )
353
- full_index : pd .Index = field (init = False )
354
370
355
371
# _ensure_1d:
356
372
group1d : T_Group = field (init = False )
@@ -394,20 +410,29 @@ def dims(self):
394
410
return self .group1d .dims
395
411
396
412
def factorize(self) -> None:
    """Encode ``self.group1d`` with the grouper and store the result on ``self``.

    Sets ``codes``, ``full_index``, ``group_indices`` and ``unique_coord``
    from the ``EncodedGroups`` returned by ``self.grouper.factorize``.
    Values the grouper left as ``None`` are inferred:

    * ``group_indices`` is rebuilt from the codes, keeping only non-empty
      groups.
    * ``unique_coord`` is built from the entries of ``full_index`` whose
      code actually occurs, skipping the ``-1`` "no group" sentinel.
    """
    encoded = self.grouper.factorize(self.group1d)

    self.codes = encoded.codes
    self.full_index = encoded.full_index

    if encoded.group_indices is not None:
        self.group_indices = encoded.group_indices
    else:
        self.group_indices = [
            g
            for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
            if g
        ]

    if encoded.unique_coord is not None:
        self.unique_coord = encoded.unique_coord
    else:
        # Same recipe BinGrouper.factorize uses: sorted distinct codes, with
        # the -1 sentinel removed so it cannot index the last entry.
        present = np.sort(pd.unique(self.codes.data))
        present = present[present >= 0]
        self.unique_coord = IndexVariable(
            self.codes.name, self.full_index[present], self.codes.attrs
        )
406
431
407
432
408
433
class Grouper(ABC):
    """Abstract interface for objects that encode a group variable."""

    @abstractmethod
    def factorize(self, group: T_Group) -> EncodedGroups:
        """Encode ``group`` and return the result as an ``EncodedGroups``.

        Concrete groupers must override this method; the base implementation
        has no behavior of its own and returns ``None``.
        """
412
437
413
438
@@ -437,34 +462,33 @@ def can_squeeze(self) -> bool:
437
462
is_dimension = self .group .dims == (self .group .name ,)
438
463
return is_dimension and self .is_unique_and_monotonic
439
464
440
def factorize(self, group1d) -> EncodedGroups:
    """Encode ``group1d``, taking the dummy path when squeezing is possible.

    The group is stored on ``self.group`` before anything else because
    ``can_squeeze`` and both helper methods read it from there.
    """
    self.group = group1d

    if not self.can_squeeze:
        return self._factorize_unique()
    return self._factorize_dummy()
447
472
448
def _factorize_unique(self) -> EncodedGroups:
    """Encode the group by its distinct values.

    ``group_indices`` is not filled in here; it is left for the consumer of
    the returned ``EncodedGroups`` to derive from the codes.

    Raises
    ------
    ValueError
        If every code is ``-1``, i.e. no element could be assigned to a group.
    """
    index = self.group_as_index
    # A MultiIndex is factorized without sorting; anything else is sorted.
    values, raw_codes = unique_value_groups(
        index, sort=not isinstance(index, pd.MultiIndex)
    )
    if (raw_codes == -1).all():
        raise ValueError(
            "Failed to group data. Are you grouping by a variable that is all NaN?"
        )

    group = self.group
    codes = group.copy(data=raw_codes)
    coord = IndexVariable(group.name, values, attrs=group.attrs)

    # NOTE(review): the same IndexVariable is passed as ``full_index``, although
    # EncodedGroups annotates that field as pd.Index — confirm this is intended.
    return EncodedGroups(codes=codes, full_index=coord, unique_coord=coord)
466
490
467
- def _factorize_dummy (self ) -> T_FactorizeOut :
491
+ def _factorize_dummy (self ) -> EncodedGroups :
468
492
size = self .group .size
469
493
# no need to factorize
470
494
# use slices to do views instead of fancy indexing
@@ -479,8 +503,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
479
503
full_index = IndexVariable (
480
504
self .group .name , unique_coord .values , self .group .attrs
481
505
)
482
-
483
- return codes , group_indices , unique_coord , full_index
506
+ return EncodedGroups (
507
+ codes = codes ,
508
+ group_indices = group_indices ,
509
+ full_index = full_index ,
510
+ unique_coord = unique_coord ,
511
+ )
484
512
485
513
486
514
@dataclass
@@ -494,7 +522,7 @@ def __post_init__(self) -> None:
494
522
if duck_array_ops .isnull (self .bins ).all ():
495
523
raise ValueError ("All bin edges are NaN." )
496
524
497
- def factorize (self , group ) -> T_FactorizeOut :
525
+ def factorize (self , group ) -> EncodedGroups :
498
526
from xarray .core .dataarray import DataArray
499
527
500
528
data = group .data
@@ -512,11 +540,7 @@ def factorize(self, group) -> T_FactorizeOut:
512
540
full_index = binned .categories
513
541
uniques = np .sort (pd .unique (binned_codes ))
514
542
unique_values = full_index [uniques [uniques != - 1 ]]
515
- group_indices = [
516
- g for g in _codes_to_groups (binned_codes , len (full_index )) if g
517
- ]
518
-
519
- if len (group_indices ) == 0 :
543
+ if (binned_codes == - 1 ).all ():
520
544
raise ValueError (
521
545
f"None of the data falls within bins with edges { self .bins !r} "
522
546
)
@@ -525,7 +549,9 @@ def factorize(self, group) -> T_FactorizeOut:
525
549
binned_codes , getattr (group , "coords" , None ), name = new_dim_name
526
550
)
527
551
unique_coord = IndexVariable (new_dim_name , pd .Index (unique_values ), group .attrs )
528
- return codes , group_indices , unique_coord , full_index
552
+ return EncodedGroups (
553
+ codes = codes , full_index = full_index , unique_coord = unique_coord
554
+ )
529
555
530
556
531
557
@dataclass
@@ -628,7 +654,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
628
654
_apply_loffset (self .loffset , first_items )
629
655
return first_items , codes
630
656
631
- def factorize (self , group ) -> T_FactorizeOut :
657
+ def factorize (self , group ) -> EncodedGroups :
632
658
self ._init_properties (group )
633
659
full_index , first_items , codes_ = self ._get_index_and_items ()
634
660
sbins = first_items .values .astype (np .int64 )
@@ -640,7 +666,12 @@ def factorize(self, group) -> T_FactorizeOut:
640
666
unique_coord = IndexVariable (group .name , first_items .index , group .attrs )
641
667
codes = group .copy (data = codes_ )
642
668
643
- return codes , group_indices , unique_coord , full_index
669
+ return EncodedGroups (
670
+ codes = codes ,
671
+ group_indices = group_indices ,
672
+ full_index = full_index ,
673
+ unique_coord = unique_coord ,
674
+ )
644
675
645
676
646
677
def _validate_groupby_squeeze (squeeze : bool | None ) -> None :
0 commit comments