Skip to content

Commit 7b86a5e

Browse files
committed
apply_func: Set meta=np.ndarray when vectorize=True and dask="parallelized"
Closes pydata#3574
1 parent b3d3b44 commit 7b86a5e

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

xarray/core/computation.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -547,6 +547,7 @@ def apply_variable_ufunc(
547547
output_dtypes=None,
548548
output_sizes=None,
549549
keep_attrs=False,
550+
vectorize=False,
550551
):
551552
"""Apply a ndarray level function over Variable and/or ndarray objects.
552553
"""
@@ -579,6 +580,7 @@ def apply_variable_ufunc(
579580
elif dask == "parallelized":
580581
input_dims = [broadcast_dims + dims for dims in signature.input_core_dims]
581582
numpy_func = func
583+
meta = np.ndarray if vectorize else None
582584

583585
def func(*arrays):
584586
return _apply_blockwise(
@@ -589,6 +591,7 @@ def func(*arrays):
589591
signature,
590592
output_dtypes,
591593
output_sizes,
594+
meta,
592595
)
593596

594597
elif dask == "allowed":
@@ -647,7 +650,14 @@ def func(*arrays):
647650

648651

649652
def _apply_blockwise(
650-
func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None
653+
func,
654+
args,
655+
input_dims,
656+
output_dims,
657+
signature,
658+
output_dtypes,
659+
output_sizes=None,
660+
meta=None,
651661
):
652662
import dask.array
653663

@@ -719,6 +729,7 @@ def _apply_blockwise(
719729
dtype=dtype,
720730
concatenate=True,
721731
new_axes=output_sizes,
732+
meta=meta,
722733
)
723734

724735

@@ -1005,6 +1016,7 @@ def earth_mover_distance(first_samples,
10051016
dask=dask,
10061017
output_dtypes=output_dtypes,
10071018
output_sizes=output_sizes,
1019+
vectorize=vectorize,
10081020
)
10091021

10101022
if any(isinstance(a, GroupBy) for a in args):

xarray/tests/test_computation.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -817,6 +817,24 @@ def test_vectorize_dask():
817817
assert_identical(expected, actual)
818818

819819

820+
@requires_dask
821+
def test_vectorize_dask_new_output_dims():
822+
# regression test for GH3574
823+
data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
824+
func = lambda x: x[np.newaxis, ...]
825+
expected = data_array.expand_dims("z")
826+
actual = apply_ufunc(
827+
func,
828+
data_array.chunk({"x": 1}),
829+
output_core_dims=[["z"]],
830+
vectorize=True,
831+
dask="parallelized",
832+
output_dtypes=[float],
833+
output_sizes={"z": 1},
834+
).transpose(*expected.dims)
835+
assert_identical(expected, actual)
836+
837+
820838
def test_output_wrong_number():
821839
variable = xr.Variable("x", np.arange(10))
822840

0 commit comments

Comments
 (0)