Skip to content

Commit 24f9292

Browse files
dcherian and crusaderky
authored
Make dask names change when chunking Variables by different amounts. (#3584)
* Make dask tokens change when chunking Variables by different amounts. When rechunking by the current chunk size, the dask token should not change. Add a __dask_tokenize__ method for ReprObject so that this behaviour is present when DataArrays are converted to temporary Datasets and back. Co-Authored-By: crusaderky <[email protected]> Co-authored-by: crusaderky <[email protected]>
1 parent ef6e6a7 commit 24f9292

File tree

6 files changed

+46
-11
lines changed

6 files changed

+46
-11
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ New Features
4747

4848
Bug fixes
4949
~~~~~~~~~
50+
5051
- Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete
5152
hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger
5253
<https://github.com/bolliger32>`_.
@@ -91,6 +92,9 @@ Documentation
9192

9293
Internal Changes
9394
~~~~~~~~~~~~~~~~
95+
- Make sure dask names change when rechunking by different chunk sizes. Conversely, make sure they
96+
stay the same when rechunking by the same chunk size. (:issue:`3350`)
97+
By `Deepak Cherian <https://github.com/dcherian>`_.
9498
- 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`,
9599
:py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int,
96100
slice, list of int, scalar ndarray, or 1-dimensional ndarray.

xarray/core/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1754,7 +1754,10 @@ def maybe_chunk(name, var, chunks):
17541754
if not chunks:
17551755
chunks = None
17561756
if var.ndim > 0:
1757-
token2 = tokenize(name, token if token else var._data)
1757+
# when rechunking by different amounts, make sure dask names change
1758+
# by providing chunks as an input to tokenize.
1759+
# subtle bugs result otherwise. see GH3350
1760+
token2 = tokenize(name, token if token else var._data, chunks)
17581761
name2 = f"{name_prefix}{name}-{token2}"
17591762
return var.chunk(chunks, name=name2, lock=lock)
17601763
else:

xarray/core/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,12 @@ def __eq__(self, other) -> bool:
547547
return False
548548

549549
def __hash__(self) -> int:
550-
return hash((ReprObject, self._value))
550+
return hash((type(self), self._value))
551+
552+
def __dask_tokenize__(self):
553+
from dask.base import normalize_token
554+
555+
return normalize_token((type(self), self._value))
551556

552557

553558
@contextlib.contextmanager

xarray/tests/test_dask.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def func(obj):
10831083
actual = xr.map_blocks(func, obj)
10841084
expected = func(obj)
10851085
assert_chunks_equal(expected.chunk(), actual)
1086-
xr.testing.assert_identical(actual.compute(), expected.compute())
1086+
assert_identical(actual, expected)
10871087

10881088

10891089
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1092,7 +1092,7 @@ def test_map_blocks_convert_args_to_list(obj):
10921092
with raise_if_dask_computes():
10931093
actual = xr.map_blocks(operator.add, obj, [10])
10941094
assert_chunks_equal(expected.chunk(), actual)
1095-
xr.testing.assert_identical(actual.compute(), expected.compute())
1095+
assert_identical(actual, expected)
10961096

10971097

10981098
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1107,7 +1107,7 @@ def add_attrs(obj):
11071107
with raise_if_dask_computes():
11081108
actual = xr.map_blocks(add_attrs, obj)
11091109

1110-
xr.testing.assert_identical(actual.compute(), expected.compute())
1110+
assert_identical(actual, expected)
11111111

11121112

11131113
def test_map_blocks_change_name(map_da):
@@ -1120,7 +1120,7 @@ def change_name(obj):
11201120
with raise_if_dask_computes():
11211121
actual = xr.map_blocks(change_name, map_da)
11221122

1123-
xr.testing.assert_identical(actual.compute(), expected.compute())
1123+
assert_identical(actual, expected)
11241124

11251125

11261126
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1129,15 +1129,15 @@ def test_map_blocks_kwargs(obj):
11291129
with raise_if_dask_computes():
11301130
actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan))
11311131
assert_chunks_equal(expected.chunk(), actual)
1132-
xr.testing.assert_identical(actual.compute(), expected.compute())
1132+
assert_identical(actual, expected)
11331133

11341134

11351135
def test_map_blocks_to_array(map_ds):
11361136
with raise_if_dask_computes():
11371137
actual = xr.map_blocks(lambda x: x.to_array(), map_ds)
11381138

11391139
# to_array does not preserve name, so cannot use assert_identical
1140-
assert_equal(actual.compute(), map_ds.to_array().compute())
1140+
assert_equal(actual, map_ds.to_array())
11411141

11421142

11431143
@pytest.mark.parametrize(
@@ -1156,7 +1156,7 @@ def test_map_blocks_da_transformations(func, map_da):
11561156
with raise_if_dask_computes():
11571157
actual = xr.map_blocks(func, map_da)
11581158

1159-
assert_identical(actual.compute(), func(map_da).compute())
1159+
assert_identical(actual, func(map_da))
11601160

11611161

11621162
@pytest.mark.parametrize(
@@ -1175,7 +1175,7 @@ def test_map_blocks_ds_transformations(func, map_ds):
11751175
with raise_if_dask_computes():
11761176
actual = xr.map_blocks(func, map_ds)
11771177

1178-
assert_identical(actual.compute(), func(map_ds).compute())
1178+
assert_identical(actual, func(map_ds))
11791179

11801180

11811181
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
@@ -1188,7 +1188,7 @@ def func(obj):
11881188
expected = xr.map_blocks(func, obj)
11891189
actual = obj.map_blocks(func)
11901190

1191-
assert_identical(expected.compute(), actual.compute())
1191+
assert_identical(expected, actual)
11921192

11931193

11941194
def test_map_blocks_hlg_layers():

xarray/tests/test_dataarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,12 +752,19 @@ def test_chunk(self):
752752

753753
blocked = unblocked.chunk()
754754
assert blocked.chunks == ((3,), (4,))
755+
first_dask_name = blocked.data.name
755756

756757
blocked = unblocked.chunk(chunks=((2, 1), (2, 2)))
757758
assert blocked.chunks == ((2, 1), (2, 2))
759+
assert blocked.data.name != first_dask_name
758760

759761
blocked = unblocked.chunk(chunks=(3, 3))
760762
assert blocked.chunks == ((3,), (3, 1))
763+
assert blocked.data.name != first_dask_name
764+
765+
# name doesn't change when rechunking by same amount
766+
# this fails if ReprObject doesn't have __dask_tokenize__ defined
767+
assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name
761768

762769
assert blocked.load().chunks is None
763770

xarray/tests/test_dataset.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,19 +936,35 @@ def test_chunk(self):
936936
expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)}
937937
assert reblocked.chunks == expected_chunks
938938

939+
def get_dask_names(ds):
940+
return {k: v.data.name for k, v in ds.items()}
941+
942+
orig_dask_names = get_dask_names(reblocked)
943+
939944
reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5})
940945
# time is not a dim in any of the data_vars, so it
941946
# doesn't get chunked
942947
expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)}
943948
assert reblocked.chunks == expected_chunks
944949

950+
# make sure dask names change when rechunking by different amounts
951+
# regression test for GH3350
952+
new_dask_names = get_dask_names(reblocked)
953+
for k, v in new_dask_names.items():
954+
assert v != orig_dask_names[k]
955+
945956
reblocked = data.chunk(expected_chunks)
946957
assert reblocked.chunks == expected_chunks
947958

948959
# reblock on already blocked data
960+
orig_dask_names = get_dask_names(reblocked)
949961
reblocked = reblocked.chunk(expected_chunks)
962+
new_dask_names = get_dask_names(reblocked)
950963
assert reblocked.chunks == expected_chunks
951964
assert_identical(reblocked, data)
965+
# rechunking with same chunk sizes should not change names
966+
for k, v in new_dask_names.items():
967+
assert v == orig_dask_names[k]
952968

953969
with raises_regex(ValueError, "some chunks"):
954970
data.chunk({"foo": 10})

0 commit comments

Comments
 (0)