Skip to content

Commit 6fba398

Browse files
committed
Cleanup
1 parent 93ef584 commit 6fba398

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

xarray/core/groupby.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from xarray.core.utils import Frozen
4545

4646
GroupKey = Any
47+
GroupIndex = int | slice | list[int]
4748

4849

4950
def check_reduce_dims(reduce_dims, dimensions):
@@ -84,7 +85,7 @@ def unique_value_groups(
8485
return values, groups, inverse
8586

8687

87-
def _codes_to_groups(inverse, N):
88+
def _codes_to_groups(inverse, N) -> list[list[int]]:
8889
groups: list[list[int]] = [[] for _ in range(N)]
8990
for n, g in enumerate(inverse):
9091
if g >= 0:
@@ -126,11 +127,11 @@ def _dummy_copy(xarray_obj):
126127
return res
127128

128129

129-
def _is_one_or_none(obj):
130+
def _is_one_or_none(obj) -> bool:
130131
return obj == 1 or obj is None
131132

132133

133-
def _consolidate_slices(slices):
134+
def _consolidate_slices(slices: list[slice]) -> list[slice]:
134135
"""Consolidate adjacent slices in a list of slices."""
135136
result = []
136137
last_slice = slice(None)
@@ -188,7 +189,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
188189
self.name = name
189190
self.coords = coords
190191
self.size = obj.sizes[name]
191-
self.dataarray = obj[name]
192192

193193
@property
194194
def dims(self) -> tuple[Hashable]:
@@ -222,6 +222,13 @@ def __getitem__(self, key):
222222
key = key[0]
223223
return self.values[key]
224224

225+
def as_dataarray(self) -> DataArray:
226+
from xarray.core.dataarray import DataArray
227+
228+
return DataArray(
229+
data=self.data, dims=(self.name,), coords=self.coords, name=self.name
230+
)
231+
225232

226233
T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
227234

@@ -288,14 +295,16 @@ def _apply_loffset(
288295

289296

290297
class Grouper:
291-
def __init__(self, group: T_Group):
292-
self.group : T_Group | None = group
293-
self.codes : np.ndarry | None = None
298+
def __init__(self, group: T_Group | Hashable):
299+
self.group: T_Group | Hashable = group
300+
294301
self.labels = None
295-
self.group_indices : list[list[int, ...]] | None= None
296-
self.unique_coord = None
297-
self.full_index : pd.Index | None = None
298-
self._group_as_index = None
302+
self._group_as_index: pd.Index | None = None
303+
304+
self.codes: DataArray
305+
self.group_indices: list[int] | list[slice] | list[list[int]]
306+
self.unique_coord: IndexVariable | _DummyGroup
307+
self.full_index: pd.Index
299308

300309
@property
301310
def name(self) -> Hashable:
@@ -328,10 +337,9 @@ def group_as_index(self) -> pd.Index:
328337
self._group_as_index = safe_cast_to_index(self.group1d)
329338
return self._group_as_index
330339

331-
def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
340+
def _resolve_group(self, obj: T_Xarray):
332341
from xarray.core.dataarray import DataArray
333342

334-
group: T_Group
335343
group = self.group
336344
if not isinstance(group, (DataArray, IndexVariable)):
337345
if not hashable(group):
@@ -340,15 +348,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
340348
"name of an xarray variable or dimension. "
341349
f"Received {group!r} instead."
342350
)
343-
group_da : T_DataArray = obj[group]
344-
if len(group_da) == 0:
345-
raise ValueError(f"{group_da.name} must not be empty")
346-
347-
if group_da.name not in obj.coords and group_da.name in obj.dims:
351+
group = obj[group]
352+
if len(group) == 0:
353+
raise ValueError(f"{group.name} must not be empty")
354+
if group.name not in obj._indexes and group.name in obj.dims:
348355
# DummyGroups should not appear on groupby results
349356
group = _DummyGroup(obj, group.name, group.coords)
350357

351-
if getattr(group, "name", None) is None:
358+
elif getattr(group, "name", None) is None:
352359
group.name = "group"
353360

354361
self.group = group
@@ -402,10 +409,10 @@ def _factorize_dummy(self, squeeze) -> None:
402409
# equivalent to: group_indices = group_indices.reshape(-1, 1)
403410
self.group_indices = [slice(i, i + 1) for i in range(size)]
404411
else:
405-
self.group_indices = np.arange(size)
412+
self.group_indices = list(range(size))
406413
codes = np.arange(size)
407414
if isinstance(self.group, _DummyGroup):
408-
self.codes = self.group.dataarray.copy(data=codes)
415+
self.codes = self.group.as_dataarray().copy(data=codes)
409416
else:
410417
self.codes = self.group.copy(data=codes)
411418
self.unique_coord = self.group
@@ -483,7 +490,7 @@ def __init__(
483490
raise ValueError("index must be monotonic for resampling")
484491

485492
if isinstance(group_as_index, CFTimeIndex):
486-
self.grouper = CFTimeGrouper(
493+
grouper = CFTimeGrouper(
487494
freq=self.freq,
488495
closed=self.closed,
489496
label=self.label,
@@ -492,15 +499,16 @@ def __init__(
492499
loffset=self.loffset,
493500
)
494501
else:
495-
self.grouper = pd.Grouper(
502+
grouper = pd.Grouper(
496503
freq=self.freq,
497504
closed=self.closed,
498505
label=self.label,
499506
origin=self.origin,
500507
offset=self.offset,
501508
)
509+
self.grouper: CFTimeGrouper | pd.Grouper = grouper
502510

503-
def _get_index_and_items(self):
511+
def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]:
504512
first_items, codes = self.first_items()
505513
full_index = first_items.index
506514
if first_items.isnull().any():
@@ -509,7 +517,7 @@ def _get_index_and_items(self):
509517
full_index = full_index.rename("__resample_dim__")
510518
return full_index, first_items, codes
511519

512-
def first_items(self):
520+
def first_items(self) -> tuple[pd.Series, np.ndarray]:
513521
from xarray import CFTimeIndex
514522

515523
if isinstance(self.group_as_index, CFTimeIndex):
@@ -664,7 +672,7 @@ def reduce(
664672
raise NotImplementedError()
665673

666674
@property
667-
def groups(self) -> dict[GroupKey, slice | int | list[int]]:
675+
def groups(self) -> dict[GroupKey, GroupIndex]:
668676
"""
669677
Mapping from group labels to indices. The indices can be used to index the underlying object.
670678
"""
@@ -729,7 +737,7 @@ def _binary_op(self, other, f, reflexive=False):
729737
dims = group.dims
730738

731739
if isinstance(group, _DummyGroup):
732-
group = coord = group.dataarray
740+
group = coord = group.as_dataarray()
733741
else:
734742
coord = grouper.unique_coord
735743
if not isinstance(coord, DataArray):

0 commit comments

Comments
 (0)