Skip to content

Commit eee5e2a

Browse files
committed
Use shuffle in groupby binary ops.
xref pydata#9546 Closes pydata#9267
1 parent 0945e0e commit eee5e2a

File tree

2 files changed

+70
-17
lines changed

2 files changed

+70
-17
lines changed

xarray/core/groupby.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from xarray.core.alignment import align, broadcast
2121
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
2222
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
23+
from xarray.core.computation import apply_ufunc
2324
from xarray.core.concat import concat
2425
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2526
from xarray.core.duck_array_ops import where
@@ -49,7 +50,7 @@
4950
peek_at,
5051
)
5152
from xarray.core.variable import IndexVariable, Variable
52-
from xarray.namedarray.pycompat import is_chunked_array
53+
from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array
5354

5455
if TYPE_CHECKING:
5556
from numpy.typing import ArrayLike
@@ -899,25 +900,75 @@ def _binary_op(self, other, f, reflexive=False):
899900
group = group.where(~mask, drop=True)
900901
codes = codes.where(~mask, drop=True).astype(int)
901902

902-
# if other is dask-backed, that's a hint that the
903-
# "expanded" dataset is too big to hold in memory.
904-
# this can be the case when `other` was read from disk
905-
# and contains our lazy indexing classes
906-
# We need to check for dask-backed Datasets
907-
# so utils.is_duck_dask_array does not work for this check
908-
if obj.chunks and not other.chunks:
909-
# TODO: What about datasets with some dask vars, and others not?
910-
# This handles dims other than `name`
911-
chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
912-
# a chunk size of 1 seems reasonable since we expect individual elements of
913-
# other to be repeated multiple times across the reduced dimension(s)
914-
chunks[name] = 1
915-
other = other.chunk(chunks)
903+
def _vindex_wrapper(array, idxr, like):
904+
# we want to use the fact that we know the chunksizes for the output (matches obj)
905+
# so we can't just use Variable's indexing
906+
import dask
907+
from dask.array.core import slices_from_chunks
908+
from dask.graph_manipulation import clone
909+
910+
array = clone(array) # FIXME: add to dask
911+
912+
assert array.ndim == 1
913+
to_shape = like.shape[-1:]
914+
to_chunks = like.chunks[-1:]
915+
flat_indices = [
916+
idxr[slicer].ravel().tolist()
917+
for slicer in slices_from_chunks(to_chunks)
918+
]
919+
# FIXME: figure out axis
920+
shuffled = dask.array.shuffle(
921+
array, flat_indices, axis=array.ndim - 1, chunks="auto"
922+
)
923+
if shuffled.shape != to_shape:
924+
return dask.array.reshape_blockwise(
925+
shuffled, shape=to_shape, chunks=to_chunks
926+
)
927+
else:
928+
return shuffled
916929

917930
# codes are defined for coord, so we align `other` with `coord`
918931
# before indexing
919932
other, _ = align(other, coord, join="right", copy=False)
920-
expanded = other.isel({name: codes})
933+
934+
other_as_dataset = (
935+
other._to_temp_dataset() if isinstance(other, DataArray) else other
936+
)
937+
obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
938+
dask_vars = []
939+
non_dask_vars = []
940+
for varname, var in other_as_dataset._variables.items():
941+
if is_duck_dask_array(var._data):
942+
dask_vars.append(varname)
943+
else:
944+
non_dask_vars.append(varname)
945+
expanded = other_as_dataset[non_dask_vars].isel({name: codes})
946+
if dask_vars:
947+
other_dims = other_as_dataset[dask_vars].dims
948+
obj_dims = obj_as_dataset[dask_vars].dims
949+
expanded = expanded.merge(
950+
apply_ufunc(
951+
_vindex_wrapper,
952+
other_as_dataset[dask_vars],
953+
codes,
954+
obj_as_dataset[dask_vars],
955+
input_core_dims=[
956+
tuple(other_dims), # FIXME: ..., name
957+
tuple(codes.dims),
958+
tuple(obj_dims),
959+
],
960+
# When other is the result of a reduction over Ellipsis
961+
# obj.dims is a superset of other.dims, and contains
962+
# dims not present in the output
963+
exclude_dims=set(obj_dims) - set(other_dims),
964+
output_core_dims=[tuple(codes.dims)],
965+
dask="allowed",
966+
join=OPTIONS["arithmetic_join"],
967+
)
968+
)
969+
970+
if isinstance(other, DataArray):
971+
expanded = other._from_temp_dataset(expanded)
921972

922973
result = g(obj, expanded)
923974

xarray/tests/test_groupby.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2654,12 +2654,14 @@ def test_groupby_math_auto_chunk() -> None:
26542654
dims=("y", "x"),
26552655
coords={"label": ("x", [2, 2, 1])},
26562656
)
2657+
# da.groupby("label").min(...)
26572658
sub = xr.DataArray(
26582659
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
26592660
)
26602661
chunked = da.chunk(x=1, y=2)
26612662
chunked.label.load()
2662-
actual = chunked.groupby("label") - sub
2663+
with raise_if_dask_computes():
2664+
actual = chunked.groupby("label") - sub
26632665
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
26642666

26652667

0 commit comments

Comments
 (0)