diff --git a/xarray/__init__.py b/xarray/__init__.py index a3df034f7c7..87722756ebb 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -34,6 +34,8 @@ from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex +from .core.parallel import map_blocks + from .util.print_versions import show_versions from . import tutorial diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py new file mode 100644 index 00000000000..f24575485c4 --- /dev/null +++ b/xarray/core/parallel.py @@ -0,0 +1,86 @@ +try: + import dask + import dask.array + from dask.highlevelgraph import HighLevelGraph + +except ImportError: + pass + +from .dataarray import DataArray + + +def get_chunk_slices(dataset): + chunk_slices = {} + for dim, chunks in dataset.chunks.items(): + slices = [] + start = 0 + for chunk in chunks: + stop = start + chunk + slices.append(slice(start, stop)) + start = stop + chunk_slices[dim] = slices + + return chunk_slices + + +def map_blocks(func, darray): + """ + A version of dask's map_blocks for DataArrays. + + Parameters + ---------- + func: callable + User-provided function that should accept DataArrays corresponding to one chunk. + darray: DataArray + Chunks of this array will be provided to 'func'. The function must not change + shape of the provided DataArray. + + Returns + ------- + DataArray + + See Also + -------- + dask.array.map_blocks + """ + + def _wrapper(darray): + result = func(darray) + if not isinstance(result, type(darray)): + raise ValueError("Result is not the same type as input.") + if result.shape != darray.shape: + raise ValueError("Result does not have the same shape as input.") + return result + + meta_array = DataArray(darray.data._meta, dims=darray.dims) + result_meta = func(meta_array) + + name = "%s-%s" % (darray.name or func.__name__, dask.base.tokenize(darray)) + + slicers = get_chunk_slices(darray._to_temp_dataset()) + dask_keys = list(dask.core.flatten(darray.__dask_keys__())) + + graph = { + (name,) + + (*key[1:],): ( + _wrapper, + ( + DataArray, + key, + { + dim_name: darray[dim_name][slicers[dim_name][index]] + for dim_name, index in zip(darray.dims, key[1:]) + }, + darray.dims, + ), + ) + for key in dask_keys + } + + graph = HighLevelGraph.from_collections(name, graph, dependencies=[darray]) + + return DataArray( + dask.array.Array(graph, name, chunks=darray.chunks, meta=result_meta), + dims=darray.dims, + coords=darray.coords, + ) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e3fc6f65e0f..8629010eec2 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -878,3 +878,25 @@ def test_dask_layers_and_dependencies(): assert set(x.foo.__dask_graph__().dependencies).issuperset( ds.__dask_graph__().dependencies ) + + +def test_map_blocks(): + darray = xr.DataArray( + dask.array.ones((10, 20), chunks=[4, 5]), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(100, 120)}, + ) + darray.name = None + + def good_func(darray): + return darray * darray.x + 5 * darray.y + + def bad_func(darray): + return (darray * darray.x + 5 * darray.y)[:1, :1] + + actual = xr.map_blocks(good_func, darray) + expected = good_func(darray) + xr.testing.assert_equal(expected, actual) + + with raises_regex(ValueError, "not have the same shape"): + xr.map_blocks(bad_func, darray).compute()