Skip to content

Commit 4710af4

Browse files
committed
Use numbagg for ffill
1 parent 22ca9ba commit 4710af4

File tree

8 files changed

+91
-44
lines changed

8 files changed

+91
-44
lines changed

xarray/backends/zarr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
177177
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
178178
# this avoids the need to get involved in zarr synchronization / locking
179179
# From zarr docs:
180-
# "If each worker in a parallel computation is writing to a separate
181-
# region of the array, and if region boundaries are perfectly aligned
180+
# "If each worker in a parallel computation is writing to a
181+
# separate region of the array, and if region boundaries are perfectly aligned
182182
# with chunk boundaries, then no synchronization is required."
183183
# TODO: incorporate synchronizer to allow writes from multiple dask
184184
# threads

xarray/core/dask_array_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ def push(array, n, axis):
5959
"""
6060
Dask-aware bottleneck.push
6161
"""
62-
import bottleneck
6362
import dask.array as da
6463
import numpy as np
6564

65+
from xarray.core.duck_array_ops import _push
66+
6667
def _fill_with_last_one(a, b):
6768
# cumreduction applies the push func over all the blocks first, so the only missing part is filling
6869
# the missing values using the last data of the previous chunk
@@ -85,7 +86,8 @@ def _fill_with_last_one(a, b):
8586

8687
# The method parameter makes the tests for Python 3.7 fail.
8788
return da.reductions.cumreduction(
88-
func=bottleneck.push,
89+
func=_push,
90+
# func=bottleneck.push,
8991
binop=_fill_with_last_one,
9092
ident=np.nan,
9193
x=array,

xarray/core/duck_array_ops.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
from numpy import concatenate as _concatenate
3232
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
3333
from numpy.lib.stride_tricks import sliding_window_view # noqa
34+
from packaging.version import Version
3435

3536
from xarray.core import dask_array_ops, dtypes, nputils
37+
from xarray.core.options import OPTIONS
3638
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
3739
from xarray.core.pycompat import array_type, is_duck_dask_array
3840
from xarray.core.utils import is_duck_array, module_available
@@ -688,13 +690,43 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
688690
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
689691

690692

691-
def push(array, n, axis):
692-
from bottleneck import push
693+
def _push(array, n: int | None = None, axis: int = -1):
694+
"""
695+
Use either bottleneck or numbagg depending on options & what's available
696+
"""
697+
from xarray.core.nputils import NUMBAGG_VERSION, numbagg
698+
699+
if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
700+
raise RuntimeError(
701+
"ffill & bfill requires bottleneck or numbagg to be enabled."
702+
" Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
703+
)
704+
if OPTIONS["use_numbagg"] and NUMBAGG_VERSION is not None:
705+
if NUMBAGG_VERSION < Version("0.6.2"):
706+
warnings.warn(
707+
f"numbagg >= 0.6.2 is required for bfill & ffill; {NUMBAGG_VERSION} is installed. We'll attempt with bottleneck instead."
708+
)
709+
else:
710+
return numbagg.ffill(array, limit=n, axis=axis)
693711

712+
# work around for bottleneck 178
713+
limit = n if n is not None else array.shape[axis]
714+
715+
import bottleneck as bn
716+
717+
return bn.push(array, limit, axis)
718+
719+
720+
def push(array, n, axis):
721+
if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
722+
raise RuntimeError(
723+
"ffill & bfill requires bottleneck or numbagg to be enabled."
724+
" Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
725+
)
694726
if is_duck_dask_array(array):
695727
return dask_array_ops.push(array, n, axis)
696728
else:
697-
return push(array, n, axis)
729+
return _push(array, n, axis)
698730

699731

700732
def _first_last_wrapper(array, *, axis, op, keepdims):

xarray/core/missing.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1):
413413

414414
def ffill(arr, dim=None, limit=None):
415415
"""forward fill missing values"""
416-
if not OPTIONS["use_bottleneck"]:
417-
raise RuntimeError(
418-
"ffill requires bottleneck to be enabled."
419-
" Call `xr.set_options(use_bottleneck=True)` to enable it."
420-
)
421416

422417
axis = arr.get_axis_num(dim)
423418

@@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None):
436431

437432
def bfill(arr, dim=None, limit=None):
438433
"""backfill missing values"""
439-
if not OPTIONS["use_bottleneck"]:
440-
raise RuntimeError(
441-
"bfill requires bottleneck to be enabled."
442-
" Call `xr.set_options(use_bottleneck=True)` to enable it."
443-
)
444434

445435
axis = arr.get_axis_num(dim)
446436

xarray/core/nputils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from typing import Callable
45

56
import numpy as np
67
import pandas as pd
@@ -25,14 +26,17 @@
2526
bn = np
2627
_BOTTLENECK_AVAILABLE = False
2728

29+
NUMBAGG_VERSION: Version | None
30+
2831
try:
2932
import numbagg
3033

31-
_HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
34+
v = getattr(numbagg, "__version__", "999")
35+
NUMBAGG_VERSION = Version(v)
3236
except ImportError:
3337
# use numpy methods instead
3438
numbagg = np
35-
_HAS_NUMBAGG = False
39+
NUMBAGG_VERSION = None
3640

3741

3842
def _select_along_axis(values, idx, axis):
@@ -171,14 +175,15 @@ def __setitem__(self, key, value):
171175
self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)
172176

173177

174-
def _create_method(name, npmodule=np):
178+
def _create_method(name, npmodule=np) -> Callable:
175179
def f(values, axis=None, **kwargs):
176180
dtype = kwargs.get("dtype", None)
177181
bn_func = getattr(bn, name, None)
178182
nba_func = getattr(numbagg, name, None)
179183

180184
if (
181-
_HAS_NUMBAGG
185+
NUMBAGG_VERSION is not None
186+
and NUMBAGG_VERSION >= Version("0.5.0")
182187
and OPTIONS["use_numbagg"]
183188
and isinstance(values, np.ndarray)
184189
and nba_func is not None

xarray/core/rolling_exp.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import numbagg
1616
from numbagg import move_exp_nanmean, move_exp_nansum
1717

18-
_NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
18+
NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
1919
except ImportError:
20-
_NUMBAGG_VERSION = None
20+
NUMBAGG_VERSION = None
2121

2222

2323
def _get_alpha(
@@ -100,17 +100,17 @@ def __init__(
100100
window_type: str = "span",
101101
min_weight: float = 0.0,
102102
):
103-
if _NUMBAGG_VERSION is None:
103+
if NUMBAGG_VERSION is None:
104104
raise ImportError(
105105
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
106106
)
107-
elif _NUMBAGG_VERSION < Version("0.2.1"):
107+
elif NUMBAGG_VERSION < Version("0.2.1"):
108108
raise ImportError(
109-
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed"
109+
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {NUMBAGG_VERSION} is installed"
110110
)
111-
elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
111+
elif NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
112112
raise ImportError(
113-
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed"
113+
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {NUMBAGG_VERSION} is installed"
114114
)
115115

116116
self.obj: T_DataWithCoords = obj
@@ -211,9 +211,9 @@ def std(self) -> T_DataWithCoords:
211211
Dimensions without coordinates: x
212212
"""
213213

214-
if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
214+
if NUMBAGG_VERSION is None or NUMBAGG_VERSION < Version("0.4.0"):
215215
raise ImportError(
216-
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed"
216+
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {NUMBAGG_VERSION} is installed"
217217
)
218218
dim_order = self.obj.dims
219219

@@ -243,9 +243,9 @@ def var(self) -> T_DataWithCoords:
243243
Dimensions without coordinates: x
244244
"""
245245

246-
if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
246+
if NUMBAGG_VERSION is None or NUMBAGG_VERSION < Version("0.4.0"):
247247
raise ImportError(
248-
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed"
248+
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {NUMBAGG_VERSION} is installed"
249249
)
250250
dim_order = self.obj.dims
251251

@@ -275,9 +275,9 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
275275
Dimensions without coordinates: x
276276
"""
277277

278-
if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
278+
if NUMBAGG_VERSION is None or NUMBAGG_VERSION < Version("0.4.0"):
279279
raise ImportError(
280-
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
280+
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {NUMBAGG_VERSION} is installed"
281281
)
282282
dim_order = self.obj.dims
283283

@@ -308,9 +308,9 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
308308
Dimensions without coordinates: x
309309
"""
310310

311-
if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
311+
if NUMBAGG_VERSION is None or NUMBAGG_VERSION < Version("0.4.0"):
312312
raise ImportError(
313-
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
313+
f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {NUMBAGG_VERSION} is installed"
314314
)
315315
dim_order = self.obj.dims
316316

xarray/tests/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def _importorskip(
5353
mod = importlib.import_module(modname)
5454
has = True
5555
if minversion is not None:
56-
if Version(mod.__version__) < Version(minversion):
56+
v = getattr(mod, "__version__", "999")
57+
if Version(v) < Version(minversion):
5758
raise ImportError("Minimum version not satisfied")
5859
except ImportError:
5960
has = False
@@ -89,6 +90,10 @@ def _importorskip(
8990
requires_scipy_or_netCDF4 = pytest.mark.skipif(
9091
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
9192
)
93+
has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
94+
requires_numbagg_or_bottleneck = pytest.mark.skipif(
95+
not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
96+
)
9297
# _importorskip does not work for development versions
9398
has_pandas_version_two = Version(pd.__version__).major >= 2
9499
requires_pandas_version_two = pytest.mark.skipif(

xarray/tests/test_missing.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
requires_bottleneck,
2525
requires_cftime,
2626
requires_dask,
27+
requires_numbagg,
28+
requires_numbagg_or_bottleneck,
2729
requires_scipy,
2830
)
2931

@@ -407,33 +409,44 @@ def test_interpolate_dask_expected_dtype(dtype, method):
407409
assert da.dtype == da.compute().dtype
408410

409411

410-
@requires_bottleneck
412+
@requires_numbagg_or_bottleneck
411413
def test_ffill():
412414
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
413415
expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x")
414416
actual = da.ffill("x")
415417
assert_equal(actual, expected)
416418

417419

418-
def test_ffill_use_bottleneck():
420+
def test_ffill_use_bottleneck_numbagg():
419421
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
420422
with xr.set_options(use_bottleneck=False):
421-
with pytest.raises(RuntimeError):
422-
da.ffill("x")
423+
with xr.set_options(use_numbagg=False):
424+
with pytest.raises(RuntimeError):
425+
da.ffill("x")
423426

424427

425428
@requires_dask
426429
def test_ffill_use_bottleneck_dask():
427430
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
428431
da = da.chunk({"x": 1})
429-
with xr.set_options(use_bottleneck=False):
432+
with xr.set_options(use_bottleneck=False, use_numbagg=False):
430433
with pytest.raises(RuntimeError):
431434
da.ffill("x")
432435

433436

437+
@requires_numbagg
438+
@requires_dask
439+
def test_ffill_use_numbagg_dask():
440+
with xr.set_options(use_bottleneck=False):
441+
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
442+
da = da.chunk(x=-1)
443+
# Succeeds with a single chunk:
444+
_ = da.ffill("x").compute()
445+
446+
434447
def test_bfill_use_bottleneck():
435448
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
436-
with xr.set_options(use_bottleneck=False):
449+
with xr.set_options(use_bottleneck=False, use_numbagg=False):
437450
with pytest.raises(RuntimeError):
438451
da.bfill("x")
439452

@@ -442,7 +455,7 @@ def test_bfill_use_bottleneck():
442455
def test_bfill_use_bottleneck_dask():
443456
da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
444457
da = da.chunk({"x": 1})
445-
with xr.set_options(use_bottleneck=False):
458+
with xr.set_options(use_bottleneck=False, use_numbagg=False):
446459
with pytest.raises(RuntimeError):
447460
da.bfill("x")
448461

0 commit comments

Comments
 (0)