Skip to content

Commit 6cfe9fa

Browse files
committed
generalized chunk_hint function inside indexing
1 parent 8bbc141 commit 6cfe9fa

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

xarray/core/indexing.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from xarray.core import duck_array_ops
1818
from xarray.core.nputils import NumpyVIndexAdapter
1919
from xarray.core.options import OPTIONS
20+
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
2021
from xarray.core.pycompat import array_type, integer_types, is_duck_dask_array
2122
from xarray.core.types import T_Xarray
2223
from xarray.core.utils import (
@@ -1075,16 +1076,15 @@ def _arrayize_vectorized_indexer(indexer, shape):
10751076
return VectorizedIndexer(tuple(new_key))
10761077

10771078

1078-
def _dask_array_with_chunks_hint(array, chunks):
1079-
"""Create a dask array using the chunks hint for dimensions of size > 1."""
1080-
import dask.array as da
1079+
def _chunked_array_with_chunks_hint(array, chunks, chunkmanager):
1080+
"""Create a chunked array using the chunks hint for dimensions of size > 1."""
10811081

10821082
if len(chunks) < array.ndim:
10831083
raise ValueError("not enough chunks in hint")
10841084
new_chunks = []
10851085
for chunk, size in zip(chunks, array.shape):
10861086
new_chunks.append(chunk if size > 1 else (1,))
1087-
return da.from_array(array, new_chunks)
1087+
return chunkmanager.from_array(array, new_chunks)
10881088

10891089

10901090
def _logical_any(args):
@@ -1098,8 +1098,11 @@ def _masked_result_drop_slice(key, data=None):
10981098
new_keys = []
10991099
for k in key:
11001100
if isinstance(k, np.ndarray):
1101-
if is_duck_dask_array(data):
1102-
new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint))
1101+
if is_chunked_array(data):
1102+
chunkmanager = get_chunked_array_type(data)
1103+
new_keys.append(
1104+
_chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager)
1105+
)
11031106
elif isinstance(data, array_type("sparse")):
11041107
import sparse
11051108

xarray/core/parallelcompat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111

12-
from xarray.core import indexing, utils
12+
from xarray.core import utils
1313
from xarray.core.pycompat import DuckArrayModule, is_chunked_array, is_duck_dask_array
1414
from xarray.core.types import T_Chunks
1515

@@ -197,6 +197,8 @@ def chunks(self, data: T_DaskArray) -> T_Chunks:
197197
def from_array(self, data: np.ndarray, chunks, **kwargs) -> T_DaskArray:
198198
import dask.array as da
199199

200+
from xarray.core import indexing
201+
200202
# dask-specific kwargs
201203
name = kwargs.pop("name", None)
202204
lock = kwargs.pop("lock", False)

0 commit comments

Comments
 (0)