diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7310b984a2c..b11d85f180d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,10 @@ v0.9.2 (unreleased) Enhancements ~~~~~~~~~~~~ +- When bottleneck version 1.1 or later is installed, use bottleneck for rolling + `var`, `argmin`, `argmax`, and `rank` computations. Also, `rolling.median` + now also accepts a `min_periods` argument (:issue:`1276`). + By `Joe Hamman `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 20cc9cc3a96..34b63ebd094 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -52,7 +52,8 @@ NAN_CUM_METHODS = ['cumsum', 'cumprod'] BOTTLENECK_ROLLING_METHODS = {'move_sum': 'sum', 'move_mean': 'mean', 'move_std': 'std', 'move_min': 'min', - 'move_max': 'max'} + 'move_max': 'max', 'move_var': 'var', + 'move_argmin': 'argmin', 'move_argmax': 'argmax'} # TODO: wrap take, dot, sort @@ -520,24 +521,42 @@ def inject_bottleneck_rolling_methods(cls): for name, f in methods: func = cls._reduce_method(f) func.__name__ = name - func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=func.__name__) + func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format( + name=func.__name__) setattr(cls, name, func) # bottleneck rolling methods if has_bottleneck: - if LooseVersion(bn.__version__) < LooseVersion('1.0'): + # TODO: Bump the required version of bottlneck to 1.1 and remove all + # these version checks (see GH#1278) + bn_version = LooseVersion(bn.__version__) + bn_min_version = LooseVersion('1.0') + bn_version_1_1 = LooseVersion('1.1') + if bn_version < bn_min_version: return for bn_name, method_name in BOTTLENECK_ROLLING_METHODS.items(): - f = getattr(bn, bn_name) - func = cls._bottleneck_reduce(f) - func.__name__ = method_name - func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=func.__name__) - setattr(cls, method_name, func) - - # bottleneck rolling methods without min_count + try: + f = getattr(bn, bn_name) + func = cls._bottleneck_reduce(f) + func.__name__ = method_name + func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format( + name=func.__name__) + setattr(cls, method_name, func) + except AttributeError as e: + # skip functions not in Bottleneck 1.0 + if ((bn_version < bn_version_1_1) and + (bn_name not in ['move_var', 'move_argmin', + 'move_argmax', 'move_rank'])): + raise e + + # bottleneck rolling methods without min_count (bn.__version__ < 1.1) f = getattr(bn, 'move_median') - func = cls._bottleneck_reduce_without_min_count(f) + if bn_version >= bn_version_1_1: + func = cls._bottleneck_reduce(f) + else: + func = cls._bottleneck_reduce_without_min_count(f) func.__name__ = 'median' - func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=func.__name__) + func.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format( + name=func.__name__) setattr(cls, 'median', func) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 06905208b58..c59aed97f1a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2448,16 +2448,24 @@ def test_rolling_properties(da): assert 'min_periods must be greater than zero' in str(exception) -@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max', 'median')) +@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'var', + 'min', 'max', 'median')) @pytest.mark.parametrize('center', (True, False, None)) @pytest.mark.parametrize('min_periods', (1, None)) def test_rolling_wrapped_bottleneck(da, name, center, min_periods): pytest.importorskip('bottleneck') import bottleneck as bn - # skip if median and min_periods - if (min_periods == 1) and (name == 'median'): - pytest.skip() + # skip if median and min_periods bottleneck version < 1.1 + if ((min_periods == 1) and + (name == 'median') and + (LooseVersion(bn.__version__) < LooseVersion('1.1'))): + pytest.skip('rolling median accepts min_periods for bottleneck 1.1') + + # skip if var and bottleneck version < 1.1 + if ((name == 'median') and + (LooseVersion(bn.__version__) < LooseVersion('1.1'))): + pytest.skip('rolling median accepts min_periods for bottleneck 1.1') # Test all bottleneck functions rolling_obj = da.rolling(time=7, min_periods=min_periods) @@ -2473,8 +2481,12 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): actual = getattr(rolling_obj, name)()['time'] assert_equal(actual, da['time']) + def test_rolling_invalid_args(da): - pytest.importorskip('bottleneck') + pytest.importorskip('bottleneck', minversion="1.0") + import bottleneck as bn + if LooseVersion(bn.__version__) >= LooseVersion('1.1'): + pytest.skip('rolling median accepts min_periods for bottleneck 1.1') with pytest.raises(ValueError) as exception: da.rolling(time=7, min_periods=1).median() assert 'Rolling.median does not' in str(exception)