44
44
from xarray .core .utils import Frozen
45
45
46
46
GroupKey = Any
47
+ GroupIndex = int | slice | list [int ]
47
48
48
49
49
50
def check_reduce_dims (reduce_dims , dimensions ):
@@ -84,7 +85,7 @@ def unique_value_groups(
84
85
return values , groups , inverse
85
86
86
87
87
- def _codes_to_groups (inverse , N ):
88
+ def _codes_to_groups (inverse , N ) -> list [ list [ int ]] :
88
89
groups : list [list [int ]] = [[] for _ in range (N )]
89
90
for n , g in enumerate (inverse ):
90
91
if g >= 0 :
@@ -126,11 +127,11 @@ def _dummy_copy(xarray_obj):
126
127
return res
127
128
128
129
129
- def _is_one_or_none (obj ):
130
+ def _is_one_or_none (obj ) -> bool :
130
131
return obj == 1 or obj is None
131
132
132
133
133
- def _consolidate_slices (slices ) :
134
+ def _consolidate_slices (slices : list [ slice ]) -> list [ slice ] :
134
135
"""Consolidate adjacent slices in a list of slices."""
135
136
result = []
136
137
last_slice = slice (None )
@@ -188,7 +189,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
188
189
self .name = name
189
190
self .coords = coords
190
191
self .size = obj .sizes [name ]
191
- self .dataarray = obj [name ]
192
192
193
193
@property
194
194
def dims (self ) -> tuple [Hashable ]:
@@ -222,6 +222,13 @@ def __getitem__(self, key):
222
222
key = key [0 ]
223
223
return self .values [key ]
224
224
225
+ def as_dataarray (self ) -> DataArray :
226
+ from xarray .core .dataarray import DataArray
227
+
228
+ return DataArray (
229
+ data = self .data , dims = (self .name ,), coords = self .coords , name = self .name
230
+ )
231
+
225
232
226
233
T_Group = TypeVar ("T_Group" , bound = Union ["DataArray" , "IndexVariable" , _DummyGroup ])
227
234
@@ -288,14 +295,16 @@ def _apply_loffset(
288
295
289
296
290
297
class Grouper :
291
- def __init__ (self , group : T_Group ):
292
- self .group : T_Group | None = group
293
- self . codes : np . ndarry | None = None
298
+ def __init__ (self , group : T_Group | Hashable ):
299
+ self .group : T_Group | Hashable = group
300
+
294
301
self .labels = None
295
- self .group_indices : list [list [int , ...]] | None = None
296
- self .unique_coord = None
297
- self .full_index : pd .Index | None = None
298
- self ._group_as_index = None
302
+ self ._group_as_index : pd .Index | None = None
303
+
304
+ self .codes : DataArray
305
+ self .group_indices : list [int ] | list [slice ] | list [list [int ]]
306
+ self .unique_coord : IndexVariable | _DummyGroup
307
+ self .full_index : pd .Index
299
308
300
309
@property
301
310
def name (self ) -> Hashable :
@@ -328,10 +337,9 @@ def group_as_index(self) -> pd.Index:
328
337
self ._group_as_index = safe_cast_to_index (self .group1d )
329
338
return self ._group_as_index
330
339
331
- def _resolve_group (self , obj : T_DataArray | T_Dataset ) -> None :
340
+ def _resolve_group (self , obj : T_Xarray ) :
332
341
from xarray .core .dataarray import DataArray
333
342
334
- group : T_Group
335
343
group = self .group
336
344
if not isinstance (group , (DataArray , IndexVariable )):
337
345
if not hashable (group ):
@@ -340,15 +348,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
340
348
"name of an xarray variable or dimension. "
341
349
f"Received { group !r} instead."
342
350
)
343
- group_da : T_DataArray = obj [group ]
344
- if len (group_da ) == 0 :
345
- raise ValueError (f"{ group_da .name } must not be empty" )
346
-
347
- if group_da .name not in obj .coords and group_da .name in obj .dims :
351
+ group = obj [group ]
352
+ if len (group ) == 0 :
353
+ raise ValueError (f"{ group .name } must not be empty" )
354
+ if group .name not in obj ._indexes and group .name in obj .dims :
348
355
# DummyGroups should not appear on groupby results
349
356
group = _DummyGroup (obj , group .name , group .coords )
350
357
351
- if getattr (group , "name" , None ) is None :
358
+ elif getattr (group , "name" , None ) is None :
352
359
group .name = "group"
353
360
354
361
self .group = group
@@ -402,10 +409,10 @@ def _factorize_dummy(self, squeeze) -> None:
402
409
# equivalent to: group_indices = group_indices.reshape(-1, 1)
403
410
self .group_indices = [slice (i , i + 1 ) for i in range (size )]
404
411
else :
405
- self .group_indices = np . arange ( size )
412
+ self .group_indices = list ( range ( size ) )
406
413
codes = np .arange (size )
407
414
if isinstance (self .group , _DummyGroup ):
408
- self .codes = self .group .dataarray .copy (data = codes )
415
+ self .codes = self .group .as_dataarray () .copy (data = codes )
409
416
else :
410
417
self .codes = self .group .copy (data = codes )
411
418
self .unique_coord = self .group
@@ -483,7 +490,7 @@ def __init__(
483
490
raise ValueError ("index must be monotonic for resampling" )
484
491
485
492
if isinstance (group_as_index , CFTimeIndex ):
486
- self . grouper = CFTimeGrouper (
493
+ grouper = CFTimeGrouper (
487
494
freq = self .freq ,
488
495
closed = self .closed ,
489
496
label = self .label ,
@@ -492,15 +499,16 @@ def __init__(
492
499
loffset = self .loffset ,
493
500
)
494
501
else :
495
- self . grouper = pd .Grouper (
502
+ grouper = pd .Grouper (
496
503
freq = self .freq ,
497
504
closed = self .closed ,
498
505
label = self .label ,
499
506
origin = self .origin ,
500
507
offset = self .offset ,
501
508
)
509
+ self .grouper : CFTimeGrouper | pd .Grouper = grouper
502
510
503
- def _get_index_and_items (self ):
511
+ def _get_index_and_items (self ) -> tuple [ pd . Index , pd . Series , np . ndarray ] :
504
512
first_items , codes = self .first_items ()
505
513
full_index = first_items .index
506
514
if first_items .isnull ().any ():
@@ -509,7 +517,7 @@ def _get_index_and_items(self):
509
517
full_index = full_index .rename ("__resample_dim__" )
510
518
return full_index , first_items , codes
511
519
512
- def first_items (self ):
520
+ def first_items (self ) -> tuple [ pd . Series , np . ndarray ] :
513
521
from xarray import CFTimeIndex
514
522
515
523
if isinstance (self .group_as_index , CFTimeIndex ):
@@ -664,7 +672,7 @@ def reduce(
664
672
raise NotImplementedError ()
665
673
666
674
@property
667
- def groups (self ) -> dict [GroupKey , slice | int | list [ int ] ]:
675
+ def groups (self ) -> dict [GroupKey , GroupIndex ]:
668
676
"""
669
677
Mapping from group labels to indices. The indices can be used to index the underlying object.
670
678
"""
@@ -729,7 +737,7 @@ def _binary_op(self, other, f, reflexive=False):
729
737
dims = group .dims
730
738
731
739
if isinstance (group , _DummyGroup ):
732
- group = coord = group .dataarray
740
+ group = coord = group .as_dataarray ()
733
741
else :
734
742
coord = grouper .unique_coord
735
743
if not isinstance (coord , DataArray ):
0 commit comments