Skip to content

Commit 270873a

Browse files
committed
Attempt to fix indexing for Dask
This is a naive attempt to make `isel` work with Dask Known limitation: it triggers the computation.
1 parent 960010b commit 270873a

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

xarray/core/indexing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datetime import timedelta
77
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
88

9+
import dask.array as da
910
import numpy as np
1011
import pandas as pd
1112

@@ -18,7 +19,7 @@
1819
is_duck_dask_array,
1920
sparse_array_type,
2021
)
21-
from .utils import maybe_cast_to_coords_dtype
22+
from .utils import is_duck_array, maybe_cast_to_coords_dtype
2223

2324

2425
def expanded_indexer(key, ndim):
@@ -307,7 +308,7 @@ def __init__(self, key):
307308
for k in key:
308309
if isinstance(k, slice):
309310
k = as_integer_slice(k)
310-
elif isinstance(k, np.ndarray):
311+
elif is_duck_array(k):
311312
if not np.issubdtype(k.dtype, np.integer):
312313
raise TypeError(
313314
f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -320,7 +321,7 @@ def __init__(self, key):
320321
"invalid indexer key: ndarray arguments "
321322
f"have different numbers of dimensions: {ndims}"
322323
)
323-
k = np.asarray(k, dtype=np.int64)
324+
k = k.astype(np.int64)
324325
else:
325326
raise TypeError(
326327
f"unexpected indexer type for {type(self).__name__}: {k!r}"
@@ -973,7 +974,6 @@ def _arrayize_vectorized_indexer(indexer, shape):
973974

974975
def _dask_array_with_chunks_hint(array, chunks):
975976
"""Create a dask array using the chunks hint for dimensions of size > 1."""
976-
import dask.array as da
977977

978978
if len(chunks) < array.ndim:
979979
raise ValueError("not enough chunks in hint")

xarray/tests/test_indexing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from . import IndexerMaker, ReturnItem, assert_array_equal
1111

12+
da = pytest.importorskip("dask.array")
13+
1214
B = IndexerMaker(indexing.BasicIndexer)
1315

1416

@@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
729731
expected = DataArray(expected_data)
730732

731733
assert [actual.data.item()] == [expected.data.item()]
734+
735+
736+
def test_indexing_dask_array():
737+
da = DataArray(
738+
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
739+
dims=("time", "x", "y"),
740+
).chunk(dict(time=-1, x=1, y=1))
741+
da[{"time": 9}] = 42
742+
743+
idx = da.argmax("time")
744+
actual = da.isel(time=idx)
745+
746+
assert np.all(actual == 42)

0 commit comments

Comments
 (0)