|
20 | 20 | from xarray.core.alignment import align, broadcast
|
21 | 21 | from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
|
22 | 22 | from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
|
| 23 | +from xarray.core.computation import apply_ufunc |
23 | 24 | from xarray.core.concat import concat
|
24 | 25 | from xarray.core.coordinates import Coordinates, _coordinates_from_variable
|
25 | 26 | from xarray.core.duck_array_ops import where
|
|
49 | 50 | peek_at,
|
50 | 51 | )
|
51 | 52 | from xarray.core.variable import IndexVariable, Variable
|
52 |
| -from xarray.namedarray.pycompat import is_chunked_array |
| 53 | +from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array |
53 | 54 |
|
54 | 55 | if TYPE_CHECKING:
|
55 | 56 | from numpy.typing import ArrayLike
|
@@ -899,25 +900,75 @@ def _binary_op(self, other, f, reflexive=False):
|
899 | 900 | group = group.where(~mask, drop=True)
|
900 | 901 | codes = codes.where(~mask, drop=True).astype(int)
|
901 | 902 |
|
902 |
| - # if other is dask-backed, that's a hint that the |
903 |
| - # "expanded" dataset is too big to hold in memory. |
904 |
| - # this can be the case when `other` was read from disk |
905 |
| - # and contains our lazy indexing classes |
906 |
| - # We need to check for dask-backed Datasets |
907 |
| - # so utils.is_duck_dask_array does not work for this check |
908 |
| - if obj.chunks and not other.chunks: |
909 |
| - # TODO: What about datasets with some dask vars, and others not? |
910 |
| - # This handles dims other than `name`` |
911 |
| - chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims} |
912 |
| - # a chunk size of 1 seems reasonable since we expect individual elements of |
913 |
| - # other to be repeated multiple times across the reduced dimension(s) |
914 |
| - chunks[name] = 1 |
915 |
| - other = other.chunk(chunks) |
def _vindex_wrapper(array, idxr, like):
    """Vectorized-index a 1-D dask ``array`` with ``idxr`` so the result
    matches the shape and chunk layout of ``like``'s last axis.

    We know the chunk sizes the output should have (they match ``obj``),
    so we can't just use Variable's generic indexing; instead we shuffle
    the array directly into the desired chunk layout and reshape
    blockwise if necessary.

    Parameters
    ----------
    array : dask.array.Array
        1-D array to be expanded by vectorized indexing.
    idxr : np.ndarray
        Integer indexer; sliced per output chunk and flattened.
    like : dask.array.Array
        Template whose trailing axis supplies the target shape/chunks.
    """
    # Import the subpackage explicitly: ``import dask`` alone does not
    # guarantee that ``dask.array`` is an accessible attribute.
    import dask.array
    from dask.array.core import slices_from_chunks
    from dask.graph_manipulation import clone

    array = clone(array)  # FIXME: add to dask

    assert array.ndim == 1
    to_shape = like.shape[-1:]
    to_chunks = like.chunks[-1:]
    # One flat index list per output chunk, following `like`'s chunk grid.
    flat_indices = [
        idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(to_chunks)
    ]
    # FIXME: figure out axis
    shuffled = dask.array.shuffle(
        array, flat_indices, axis=array.ndim - 1, chunks="auto"
    )
    if shuffled.shape != to_shape:
        # Restore the target shape without rechunking across blocks.
        return dask.array.reshape_blockwise(
            shuffled, shape=to_shape, chunks=to_chunks
        )
    else:
        return shuffled
916 | 929 |
|
917 | 930 | # codes are defined for coord, so we align `other` with `coord`
|
918 | 931 | # before indexing
|
919 | 932 | other, _ = align(other, coord, join="right", copy=False)
|
920 |
| - expanded = other.isel({name: codes}) |
| 933 | + |
| 934 | + other_as_dataset = ( |
| 935 | + other._to_temp_dataset() if isinstance(other, DataArray) else other |
| 936 | + ) |
| 937 | + obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj |
| 938 | + dask_vars = [] |
| 939 | + non_dask_vars = [] |
| 940 | + for varname, var in other_as_dataset._variables.items(): |
| 941 | + if is_duck_dask_array(var._data): |
| 942 | + dask_vars.append(varname) |
| 943 | + else: |
| 944 | + non_dask_vars.append(varname) |
| 945 | + expanded = other_as_dataset[non_dask_vars].isel({name: codes}) |
| 946 | + if dask_vars: |
| 947 | + other_dims = other_as_dataset[dask_vars].dims |
| 948 | + obj_dims = obj_as_dataset[dask_vars].dims |
| 949 | + expanded = expanded.merge( |
| 950 | + apply_ufunc( |
| 951 | + _vindex_wrapper, |
| 952 | + other_as_dataset[dask_vars], |
| 953 | + codes, |
| 954 | + obj_as_dataset[dask_vars], |
| 955 | + input_core_dims=[ |
| 956 | + tuple(other_dims), # FIXME: ..., name |
| 957 | + tuple(codes.dims), |
| 958 | + tuple(obj_dims), |
| 959 | + ], |
| 960 | + # When other is the result of a reduction over Ellipsis |
| 961 | + # obj.dims is a superset of other.dims, and contains |
| 962 | + # dims not present in the output |
| 963 | + exclude_dims=set(obj_dims) - set(other_dims), |
| 964 | + output_core_dims=[tuple(codes.dims)], |
| 965 | + dask="allowed", |
| 966 | + join=OPTIONS["arithmetic_join"], |
| 967 | + ) |
| 968 | + ) |
| 969 | + |
| 970 | + if isinstance(other, DataArray): |
| 971 | + expanded = other._from_temp_dataset(expanded) |
921 | 972 |
|
922 | 973 | result = g(obj, expanded)
|
923 | 974 |
|
|
0 commit comments