Commit 924bf69

review feedback:

1. skip index graph nodes.
2. var → name
3. quicker dataarray creation.
4. Add restrictions to docstring.
5. rename chunk construction task.
6. error when non-xarray object is returned.
7. restore non-coord dims.
1 parent adbe48e commit 924bf69
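
For orientation, here is a minimal usage sketch of the behaviour these changes target (not part of the commit; it assumes the map_blocks API this PR introduces), in particular item 6, the error raised when the mapped function does not hand back an xarray object:

# Illustrative sketch only, not code from this commit.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims="x").chunk({"x": 4})

# Fine: the user function returns a DataArray, so a template can be inferred.
result = xr.map_blocks(lambda block: block + 1, da)

# After this commit, returning a bare numpy array raises
# "Function must return an xarray DataArray or Dataset. Instead it returned ...":
# xr.map_blocks(lambda block: (block + 1).values, da)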

File tree

2 files changed: +104 -41 lines changed

xarray/core/parallel.py

Lines changed: 57 additions & 40 deletions
@@ -35,7 +35,7 @@ def make_meta(obj):
     from dask.array.utils import meta_from_array
 
     if isinstance(obj, DataArray):
-        meta = DataArray(obj.data._meta, dims=obj.dims)
+        meta = DataArray(obj.data._meta, dims=obj.dims, name=obj.name)
 
     if isinstance(obj, Dataset):
         meta = Dataset()
@@ -45,9 +45,14 @@ def make_meta(obj):
             else:
                 meta_obj = meta_from_array(obj[name].data)
             meta[name] = DataArray(meta_obj, dims=obj[name].dims)
+            # meta[name] = DataArray(obj[name].dims, meta_obj)
     else:
         meta = obj
 
+    # TODO: deal with non-dim coords
+    # for coord_name in (set(obj.coords) - set(obj.dims)): # DataArrays should have _coord_names!
+    #     coord = obj[coord_name]
+
     return meta
 
 
@@ -65,7 +70,7 @@ def infer_template(func, obj, *args, **kwargs):
     return template
 
 
-def _make_dict(x):
+def make_dict(x):
     # Dataset.to_dict() is too complicated
     # maps variable name to numpy array
     if isinstance(x, DataArray):
@@ -93,6 +98,9 @@ def map_blocks(func, obj, *args, **kwargs):
         properties of the returned object such as dtype, variable names,
         new dimensions and new indexes (if any).
 
+        This function must
+        - return either a DataArray or a Dataset
+
        This function cannot
         - change size of existing dimensions.
         - add new chunked dimensions.
@@ -101,18 +109,24 @@ def map_blocks(func, obj, *args, **kwargs):
         Chunks of this object will be provided to 'func'. The function must not change
         shape of the provided DataArray.
     args:
-        Passed on to func.
+        Passed on to func. Cannot include chunked xarray objects.
     kwargs:
-        Passed on to func.
+        Passed on to func. Cannot include chunked xarray objects.
 
 
     Returns
     -------
     DataArray or Dataset
 
+    Notes
+    -----
+
+    This function is designed to work with dask-backed xarray objects. See apply_ufunc for
+    a similar function that works with numpy arrays.
+
     See Also
     --------
-    dask.array.map_blocks
+    dask.array.map_blocks, xarray.apply_ufunc
     """
 
     def _wrapper(func, obj, to_array, args, kwargs):
@@ -129,7 +143,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
                         % name
                     )
 
-        to_return = _make_dict(result)
+        to_return = make_dict(result)
 
         return to_return
 
@@ -149,26 +163,30 @@ def _wrapper(func, obj, to_array, args, kwargs):
     if isinstance(template, DataArray):
         result_is_array = True
         template = template._to_temp_dataset()
-    else:
+    elif isinstance(template, Dataset):
         result_is_array = False
+    else:
+        raise ValueError(
+            "Function must return an xarray DataArray or Dataset. Instead it returned %r"
+            % type(template)
+        )
 
     # If two different variables have different chunking along the same dim
     # .chunks will raise an error.
     input_chunks = dataset.chunks
 
-    indexes = dict(dataset.indexes)
-    for dim in template.indexes:
-        if dim not in indexes:
-            indexes[dim] = template.indexes[dim]
+    # TODO: add a test that fails when template and dataset are switched
+    indexes = dict(template.indexes)
+    indexes.update(dataset.indexes)
 
     graph = {}
     gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))
 
     # map dims to list of chunk indexes
-    ichunk = {dim: range(len(input_chunks[dim])) for dim in input_chunks}
+    ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}
     # mapping from chunk index to slice bounds
     chunk_index_bounds = {
-        dim: np.cumsum((0,) + input_chunks[dim]) for dim in input_chunks
+        dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
     }
 
     # iterate over all possible chunk combinations
@@ -185,17 +203,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
         for name, variable in dataset.variables.items():
             # make a task that creates tuple of (dims, chunk)
             if dask.is_dask_collection(variable.data):
-                var_dask_keys = variable.__dask_keys__()
-
                 # recursively index into dask_keys nested list to get chunk
-                chunk = var_dask_keys
+                chunk = variable.__dask_keys__()
                 for dim in variable.dims:
                     chunk = chunk[chunk_index_dict[dim]]
 
-                task_name = ("tuple-" + dask.base.tokenize(chunk),) + v
-                graph[task_name] = (tuple, [variable.dims, chunk])
+                chunk_variable_task = ("tuple-" + dask.base.tokenize(chunk),) + v
+                graph[chunk_variable_task] = (tuple, [variable.dims, chunk])
             else:
-                # numpy array with possibly chunked dimensions
+                # non-dask array with possibly chunked dimensions
                 # index into variable appropriately
                 subsetter = dict()
                 for dim in variable.dims:
@@ -207,14 +223,14 @@ def _wrapper(func, obj, to_array, args, kwargs):
                     )
 
                 subset = variable.isel(subsetter)
-                task_name = (name + dask.base.tokenize(subset),) + v
-                graph[task_name] = (tuple, [subset.dims, subset])
+                chunk_variable_task = (name + dask.base.tokenize(subset),) + v
+                graph[chunk_variable_task] = (tuple, [subset.dims, subset])
 
             # this task creates dict mapping variable name to above tuple
-            if name in dataset.data_vars:
-                data_vars.append([name, task_name])
-            if name in dataset.coords:
-                coords.append([name, task_name])
+            if name in dataset._coord_names:
+                coords.append([name, chunk_variable_task])
+            else:
+                data_vars.append([name, chunk_variable_task])
 
         from_wrapper = (gname,) + v
         graph[from_wrapper] = (
@@ -229,14 +245,15 @@ def _wrapper(func, obj, to_array, args, kwargs):
         # mapping from variable name to dask graph key
         var_key_map = {}
         for name, variable in template.variables.items():
-            var_dims = variable.dims
+            if name in indexes:
+                continue
             # cannot tokenize "name" because the hash of <this-array> is not invariant!
             # This happens when the user function does not set a name on the returned DataArray
             gname_l = "%s-%s" % (gname, name)
             var_key_map[name] = gname_l
 
             key = (gname_l,)
-            for dim in var_dims:
+            for dim in variable.dims:
                 if dim in chunk_index_dict:
                     key += (chunk_index_dict[dim],)
                 else:
@@ -248,26 +265,26 @@ def _wrapper(func, obj, to_array, args, kwargs):
     graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset])
 
     result = Dataset()
-    for var, key in var_key_map.items():
-        # indexes need to be known
-        # otherwise compute is called when DataArray is created
-        if var in indexes:
-            result[var] = indexes[var]
-            continue
-
-        dims = template[var].dims
+    # a quicker way to assign indexes?
+    # indexes need to be known
+    # otherwise compute is called when DataArray is created
+    for name in template.indexes:
+        result[name] = indexes[name]
+    for name, key in var_key_map.items():
+        dims = template[name].dims
         var_chunks = []
         for dim in dims:
             if dim in input_chunks:
                 var_chunks.append(input_chunks[dim])
-            else:
-                if dim in indexes:
-                    var_chunks.append((len(indexes[dim]),))
+            elif dim in indexes:
+                var_chunks.append((len(indexes[dim]),))
 
         data = dask.array.Array(
-            graph, name=key, chunks=var_chunks, dtype=template[var].dtype
+            graph, name=key, chunks=var_chunks, dtype=template[name].dtype
        )
-        result[var] = DataArray(data=data, dims=dims, name=var)
+        result[name] = (dims, data)
+
+    result = result.set_coords(template._coord_names)
 
     if result_is_array:
         result = _to_array(result)
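
As an aside, the two comprehensions rewritten in the chunk-bookkeeping hunk above can be sanity-checked on a toy chunks mapping; the values below are made up for illustration and only mirror what the code does with dataset.chunks:

# Illustrative sketch only, not code from this commit.
import numpy as np

input_chunks = {"x": (4, 4, 2), "y": (5, 5)}  # hypothetical Dataset.chunks

# map dims to list of chunk indexes
ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}

# mapping from chunk index to slice bounds
chunk_index_bounds = {
    dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
}

print(list(ichunk["x"]))        # [0, 1, 2]
print(chunk_index_bounds["x"])  # [ 0  4  8 10]; chunk i spans bounds[i]:bounds[i + 1]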

xarray/tests/test_dask.py

Lines changed: 47 additions & 1 deletion
@@ -891,7 +891,7 @@ def make_da():
 
 def make_ds():
     map_ds = xr.Dataset()
-    map_ds["a"] = map_da
+    map_ds["a"] = make_da()
     map_ds["b"] = map_ds.a + 50
     map_ds["c"] = map_ds.x + 20
     map_ds = map_ds.chunk({"x": 4, "y": 5})
@@ -909,6 +909,18 @@ def make_ds():
 map_ds = make_ds()
 
 
+# DataArray.chunks is not a dict but Dataset.chunks is!
+def assert_chunks_equal(a, b):
+
+    if isinstance(a, DataArray):
+        a = a._to_temp_dataset()
+
+    if isinstance(b, DataArray):
+        b = b._to_temp_dataset()
+
+    assert a.chunks == b.chunks
+
+
 def simple_func(obj):
     result = obj.x + 5 * obj.y
     return result
@@ -933,6 +945,12 @@ def bad_func(darray):
     with raises_regex(ValueError, "Length of the.* has changed."):
         xr.map_blocks(bad_func, map_da).compute()
 
+    def returns_numpy(darray):
+        return (darray * darray.x + 5 * darray.y).values
+
+    with raises_regex(ValueError, "Function must return an xarray DataArray"):
+        xr.map_blocks(returns_numpy, map_da)
+
 
 @pytest.mark.parametrize(
     "func, obj",
@@ -942,6 +960,7 @@ def test_map_blocks(func, obj):
 
     actual = xr.map_blocks(func, obj)
     expected = func(obj)
+    assert_chunks_equal(expected, actual)
     xr.testing.assert_equal(expected, actual)
 
 
@@ -951,4 +970,31 @@ def test_map_blocks_args(obj):
 
     expected = obj + 10
     actual = xr.map_blocks(operator.add, obj, 10)
+    assert_chunks_equal(expected, actual)
     xr.testing.assert_equal(expected, actual)
+
+
+def da_to_ds(da):
+    return da.to_dataset()
+
+
+def ds_to_da(ds):
+    return ds.to_array()
+
+
+@pytest.mark.parametrize(
+    "func, obj, return_type",
+    [[da_to_ds, map_da, xr.Dataset], [ds_to_da, map_ds, xr.DataArray]],
+)
+def map_blocks_transformations(func, obj, return_type):
+    assert isinstance(xr.map_blocks(func, obj), return_type)
+
+
+# func(DataArray) -> Dataset
+# func(Dataset) -> DataArray
+# func output contains less variables
+# func output contains new variables
+# func changes dtypes
+# func output contains less (or more) dimensions
+# *args, **kwargs are passed through
+# IndexVariables don't accidentally cause the whole graph to be computed (the logic you wrote in the main function is quite subtle!)
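
A note on the assert_chunks_equal helper added above: DataArray.chunks is a tuple of per-dimension chunk tuples, while Dataset.chunks is a mapping keyed by dimension name, which is why the helper converts DataArrays to temporary Datasets before comparing. A small sketch (not part of the commit):

# Illustrative sketch only, not code from this commit.
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((10, 8)), dims=["x", "y"]).chunk({"x": 4, "y": 5})

print(da.chunks)                           # ((4, 4, 2), (5, 3)) -- positional, per dimension
print(dict(da._to_temp_dataset().chunks))  # {'x': (4, 4, 2), 'y': (5, 3)} -- keyed by dim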
