Skip to content

Commit b3d3b44

Browse files
authored
Add nanmedian for dask arrays (pydata#3604)
* Add nanmedian for dask arrays Close pydata#2999 * Fix tests. * fix import * Make sure that we don't rechunk the entire variable to one chunk by reducing over all dimensions. Dask raises an error when axis=None but not when axis=range(a.ndim). * fix tests. * Update whats-new.rst
1 parent cc22f41 commit b3d3b44

File tree

5 files changed

+102
-8
lines changed

5 files changed

+102
-8
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ Breaking changes
2525

2626
New Features
2727
~~~~~~~~~~~~
28+
- Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking
29+
to a single chunk along all reduction axes. (:issue:`2999`).
30+
By `Deepak Cherian <https://github.com/dcherian>`_.
2831
- :py:func:`xarray.concat` now preserves attributes from the first Variable.
2932
(:issue:`2575`, :issue:`2060`, :issue:`1614`)
3033
By `Deepak Cherian <https://github.com/dcherian>`_.

xarray/core/dask_array_compat.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from distutils.version import LooseVersion
2+
from typing import Iterable
23

3-
import dask.array as da
44
import numpy as np
5-
from dask import __version__ as dask_version
5+
6+
try:
7+
import dask.array as da
8+
from dask import __version__ as dask_version
9+
except ImportError:
10+
dask_version = "0.0.0"
11+
da = None
612

713
if LooseVersion(dask_version) >= LooseVersion("2.0.0"):
814
meta_from_array = da.utils.meta_from_array
@@ -89,3 +95,76 @@ def meta_from_array(x, ndim=None, dtype=None):
8995
meta = meta.astype(dtype)
9096

9197
return meta
98+
99+
100+
if LooseVersion(dask_version) >= LooseVersion("2.8.1"):
101+
median = da.median
102+
else:
103+
# Copied from dask v2.8.1
104+
# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
105+
def median(a, axis=None, keepdims=False):
106+
"""
107+
This works by automatically chunking the reduced axes to a single chunk
108+
and then calling ``numpy.median`` function across the remaining dimensions
109+
"""
110+
111+
if axis is None:
112+
raise NotImplementedError(
113+
"The da.median function only works along an axis. "
114+
"The full algorithm is difficult to do in parallel"
115+
)
116+
117+
if not isinstance(axis, Iterable):
118+
axis = (axis,)
119+
120+
axis = [ax + a.ndim if ax < 0 else ax for ax in axis]
121+
122+
a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})
123+
124+
result = a.map_blocks(
125+
np.median,
126+
axis=axis,
127+
keepdims=keepdims,
128+
drop_axis=axis if not keepdims else None,
129+
chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)]
130+
if keepdims
131+
else None,
132+
)
133+
134+
return result
135+
136+
137+
if LooseVersion(dask_version) > LooseVersion("2.9.0"):
138+
nanmedian = da.nanmedian
139+
else:
140+
141+
def nanmedian(a, axis=None, keepdims=False):
142+
"""
143+
This works by automatically chunking the reduced axes to a single chunk
144+
and then calling ``numpy.nanmedian`` function across the remaining dimensions
145+
"""
146+
147+
if axis is None:
148+
raise NotImplementedError(
149+
"The da.nanmedian function only works along an axis. "
150+
"The full algorithm is difficult to do in parallel"
151+
)
152+
153+
if not isinstance(axis, Iterable):
154+
axis = (axis,)
155+
156+
axis = [ax + a.ndim if ax < 0 else ax for ax in axis]
157+
158+
a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})
159+
160+
result = a.map_blocks(
161+
np.nanmedian,
162+
axis=axis,
163+
keepdims=keepdims,
164+
drop_axis=axis if not keepdims else None,
165+
chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)]
166+
if keepdims
167+
else None,
168+
)
169+
170+
return result

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import pandas as pd
1313

14-
from . import dask_array_ops, dtypes, npcompat, nputils
14+
from . import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils
1515
from .nputils import nanfirst, nanlast
1616
from .pycompat import dask_array_type
1717

@@ -284,7 +284,7 @@ def _ignore_warnings_if(condition):
284284
yield
285285

286286

287-
def _create_nan_agg_method(name, coerce_strings=False):
287+
def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False):
288288
from . import nanops
289289

290290
def f(values, axis=None, skipna=None, **kwargs):
@@ -301,7 +301,7 @@ def f(values, axis=None, skipna=None, **kwargs):
301301
nanname = "nan" + name
302302
func = getattr(nanops, nanname)
303303
else:
304-
func = _dask_or_eager_func(name)
304+
func = _dask_or_eager_func(name, dask_module=dask_module)
305305

306306
try:
307307
return func(values, axis=axis, **kwargs)
@@ -337,7 +337,7 @@ def f(values, axis=None, skipna=None, **kwargs):
337337
std.numeric_only = True
338338
var = _create_nan_agg_method("var")
339339
var.numeric_only = True
340-
median = _create_nan_agg_method("median")
340+
median = _create_nan_agg_method("median", dask_module=dask_array_compat)
341341
median.numeric_only = True
342342
prod = _create_nan_agg_method("prod")
343343
prod.numeric_only = True

xarray/core/nanops.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
try:
88
import dask.array as dask_array
9+
from . import dask_array_compat
910
except ImportError:
1011
dask_array = None
12+
dask_array_compat = None # type: ignore
1113

1214

1315
def _replace_nan(a, val):
@@ -141,7 +143,15 @@ def nanmean(a, axis=None, dtype=None, out=None):
141143

142144

143145
def nanmedian(a, axis=None, out=None):
144-
return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis)
146+
# The dask algorithm works by rechunking to one chunk along axis
147+
# Make sure we trigger the dask error when passing all dimensions
148+
# so that we don't rechunk the entire array to one chunk and
149+
# possibly blow memory
150+
if axis is not None and len(np.atleast_1d(axis)) == a.ndim:
151+
axis = None
152+
return _dask_or_eager_func(
153+
"nanmedian", dask_module=dask_array_compat, eager_module=nputils
154+
)(a, axis=axis)
145155

146156

147157
def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):

xarray/tests/test_dask.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,10 @@ def test_reduce(self):
216216
self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
217217
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
218218
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
219-
with raises_regex(NotImplementedError, "dask"):
219+
with raises_regex(NotImplementedError, "only works along an axis"):
220220
v.median()
221+
with raises_regex(NotImplementedError, "only works along an axis"):
222+
v.median(v.dims)
221223
with raise_if_dask_computes():
222224
v.reduce(duck_array_ops.mean)
223225

0 commit comments

Comments
 (0)