Skip to content

Commit bc4271c

Browse files
committed
Attempt to fix indexing for Dask
This is a naive attempt to make `isel` work with Dask Known limitation: it triggers the computation.
1 parent dbc02d4 commit bc4271c

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

xarray/core/indexing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
is_duck_dask_array,
1919
sparse_array_type,
2020
)
21-
from .utils import maybe_cast_to_coords_dtype
21+
from .utils import is_duck_array, maybe_cast_to_coords_dtype
2222

2323

2424
def expanded_indexer(key, ndim):
@@ -307,7 +307,7 @@ def __init__(self, key):
307307
for k in key:
308308
if isinstance(k, slice):
309309
k = as_integer_slice(k)
310-
elif isinstance(k, np.ndarray):
310+
elif is_duck_array(k):
311311
if not np.issubdtype(k.dtype, np.integer):
312312
raise TypeError(
313313
f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -320,7 +320,7 @@ def __init__(self, key):
320320
"invalid indexer key: ndarray arguments "
321321
f"have different numbers of dimensions: {ndims}"
322322
)
323-
k = np.asarray(k, dtype=np.int64)
323+
k = k.astype(np.int64)
324324
else:
325325
raise TypeError(
326326
f"unexpected indexer type for {type(self).__name__}: {k!r}"

xarray/tests/test_indexing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from . import IndexerMaker, ReturnItem, assert_array_equal
1111

12+
da = pytest.importorskip("dask.array")
13+
1214
B = IndexerMaker(indexing.BasicIndexer)
1315

1416

@@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
729731
expected = DataArray(expected_data)
730732

731733
assert [actual.data.item()] == [expected.data.item()]
734+
735+
736+
def test_indexing_dask_array():
737+
da = DataArray(
738+
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
739+
dims=("time", "x", "y"),
740+
).chunk(dict(time=-1, x=1, y=1))
741+
da[{"time": 9}] = 42
742+
743+
idx = da.argmax("time")
744+
actual = da.isel(time=idx)
745+
746+
assert np.all(actual == 42)

0 commit comments

Comments
 (0)