Skip to content

Strided rolling #3607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,8 @@ New Features
- Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen`
and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`)
By `Deepak Cherian <https://github.com/dcherian>`_
- :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` now have a stride option
By `Matthias Meyer <https://github.com/niowniow>`_.
- Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`;
this is passed on to :py:func:`dask.array.blockwise`. (:pull:`3660`)
By `Deepak Cherian <https://github.com/dcherian>`_.
Expand Down
10 changes: 9 additions & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ def rolling(
self,
dim: Mapping[Hashable, int] = None,
min_periods: int = None,
stride: int = 1,
center: Union[bool, Mapping[Hashable, bool]] = False,
keep_attrs: bool = None,
**window_kwargs: int,
Expand All @@ -838,6 +839,8 @@ def rolling(
setting min_periods equal to the size of the window.
center : bool or mapping, default: False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Expand Down Expand Up @@ -890,7 +893,12 @@ def rolling(

dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return self._rolling_cls(
self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs
self,
dim,
min_periods=min_periods,
center=center,
stride=stride,
keep_attrs=keep_attrs,
)

def rolling_exp(
Expand Down
64 changes: 50 additions & 14 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,11 @@ def _get_keep_attrs(self, keep_attrs):


class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)
__slots__ = ("window_labels", "stride")

def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
def __init__(
self, obj, windows, min_periods=None, center=False, stride=1, keep_attrs=None
):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
Expand All @@ -221,6 +223,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
setting min_periods equal to the size of the window.
center : bool, default: False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window

Returns
-------
Expand All @@ -237,16 +241,27 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs
)

# TODO legacy attribute
self.window_labels = self.obj[self.dim[0]]
if stride is None:
self.stride = 1
else:
self.stride = stride

window_labels = self.obj[self.dim[0]]
self.window_labels = window_labels[:: self.stride]

def __iter__(self):
if len(self.dim) > 1:
raise ValueError("__iter__ is only supported for 1d-rolling")
stops = np.arange(1, len(self.window_labels) + 1)
stops = np.arange(1, len(self.window_labels) * self.stride + 1)
starts = stops - int(self.window[0])
starts[: int(self.window[0])] = 0
for (label, start, stop) in zip(self.window_labels, starts, stops):

# apply striding
stops = stops[:: self.stride]
starts = starts[:: self.stride]
window_labels = self.window_labels

for (label, start, stop) in zip(window_labels, starts, stops):
window = self.obj.isel(**{self.dim[0]: slice(start, stop)})

counts = window.count(dim=self.dim[0])
Expand All @@ -257,7 +272,7 @@ def __iter__(self):
def construct(
self,
window_dim=None,
stride=1,
stride=None,
fill_value=dtypes.NA,
keep_attrs=None,
**window_dim_kwargs,
Expand Down Expand Up @@ -340,6 +355,9 @@ def _construct(
):
from .dataarray import DataArray

if stride is None:
stride = self.stride

keep_attrs = self._get_keep_attrs(keep_attrs)

if window_dim is None:
Expand Down Expand Up @@ -438,7 +456,11 @@ def reduce(self, func, keep_attrs=None, **kwargs):
else:
obj = self.obj
windows = self._construct(
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
obj,
rolling_dim,
keep_attrs=keep_attrs,
fill_value=fillna,
stride=self.stride,
)

result = windows.reduce(
Expand Down Expand Up @@ -466,7 +488,9 @@ def _counts(self, keep_attrs):
center={d: self.center[i] for i, d in enumerate(self.dim)},
**{d: w for d, w in zip(self.dim, self.window)},
)
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
.construct(
rolling_dim, fill_value=False, stride=self.stride, keep_attrs=keep_attrs
)
.sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
)
return counts
Expand Down Expand Up @@ -509,6 +533,7 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
if self.center[0]:
values = values[valid]

values = values.isel(**{self.dim: slice(None, None, self.stride)})
attrs = self.obj.attrs if keep_attrs else {}

return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name)
Expand Down Expand Up @@ -557,9 +582,11 @@ def _numpy_or_bottleneck_reduce(


class DatasetRolling(Rolling):
__slots__ = ("rollings",)
__slots__ = ("rollings", "stride")

def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None):
def __init__(
self, obj, windows, min_periods=None, center=False, stride=1, keep_attrs=None
):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
Expand All @@ -578,6 +605,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
setting min_periods equal to the size of the window.
center : bool or mapping of hashable to bool, default: False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window

Returns
-------
Expand All @@ -593,6 +622,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None
super().__init__(obj, windows, min_periods, center, keep_attrs)
if any(d not in self.obj.dims for d in self.dim):
raise KeyError(self.dim)
self.stride = stride
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
Expand All @@ -605,7 +635,9 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None

if dims:
w = {d: windows[d] for d in dims}
self.rollings[key] = DataArrayRolling(da, w, min_periods, center)
self.rollings[key] = DataArrayRolling(
da, w, min_periods, center, stride=stride
)

def _dataset_implementation(self, func, keep_attrs, **kwargs):
from .dataset import Dataset
Expand All @@ -623,7 +655,9 @@ def _dataset_implementation(self, func, keep_attrs, **kwargs):
reduced[key].attrs = {}

attrs = self.obj.attrs if keep_attrs else {}
return Dataset(reduced, coords=self.obj.coords, attrs=attrs)
return Dataset(reduced, coords=self.obj.coords, attrs=attrs).isel(
**{self.dim: slice(None, None, self.stride)}
)

def reduce(self, func, keep_attrs=None, **kwargs):
"""Reduce the items in this group by applying `func` along some
Expand Down Expand Up @@ -680,7 +714,7 @@ def _numpy_or_bottleneck_reduce(
def construct(
self,
window_dim=None,
stride=1,
stride=None,
fill_value=dtypes.NA,
keep_attrs=None,
**window_dim_kwargs,
Expand Down Expand Up @@ -708,6 +742,8 @@ def construct(

from .dataset import Dataset

if stride is None:
stride = self.stride
keep_attrs = self._get_keep_attrs(keep_attrs)

if window_dim is None:
Expand Down
11 changes: 7 additions & 4 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,10 +805,13 @@ def _process_cmap_cbar_kwargs(
if func.__name__ == "surface":
# Leave user to specify cmap settings for surface plots
kwargs["cmap"] = cmap
return {
k: kwargs.get(k, None)
for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
}, {}
return (
{
k: kwargs.get(k, None)
for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
},
{},
)

cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)

Expand Down
68 changes: 68 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6780,6 +6780,74 @@ def test_rolling_construct(center, window):
assert (da_rolling_mean == 0.0).sum() >= 0


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("stride", (1, 2, None))
def test_rolling_stride(center, window, stride):
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)

s_rolling = s.rolling(window, center=center, min_periods=1).mean()
da_rolling_strided = da.rolling(
index=window, center=center, min_periods=1, stride=stride
)

if stride is None:
stride_index = 1
else:
stride_index = stride

# with construct
da_rolling_mean = da_rolling_strided.construct("window").mean("window")
np.testing.assert_allclose(s_rolling.values[::stride_index], da_rolling_mean.values)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_mean["index"]
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_mean["index"]
)

# with bottleneck
da_rolling_strided_mean = da_rolling_strided.mean()
np.testing.assert_allclose(
s_rolling.values[::stride_index], da_rolling_strided_mean.values
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
)

# with fill_value
da_rolling_mean = da_rolling_strided.construct("window", fill_value=0.0).mean(
"window"
)
assert da_rolling_mean.isnull().sum() == 0
assert (da_rolling_mean == 0.0).sum() >= 0

# with iter
assert len(da_rolling_strided.window_labels) == len(da["index"]) // stride_index
assert_identical(da_rolling_strided.window_labels, da["index"][::stride_index])

for i, (label, window_da) in enumerate(da_rolling_strided):
assert label == da["index"].isel(index=i * stride_index)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Mean of empty slice")
actual = da_rolling_strided_mean.isel(index=i)
expected = window_da.mean("index")

# TODO add assert_allclose_with_nan, which compares nan position
# as well as the closeness of the values.
assert_array_equal(actual.isnull(), expected.isnull())
if (~actual.isnull()).sum() > 0:
np.allclose(
actual.values,
expected.values,
)


@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
Expand Down
55 changes: 55 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6677,6 +6677,61 @@ def test_rolling_construct(center, window):
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("stride", (1, 2, None))
def test_rolling_stride(center, window, stride):
df = pd.DataFrame(
{
"x": np.random.randn(20),
"y": np.random.randn(20),
"time": np.linspace(0, 1, 20),
}
)
ds = Dataset.from_dataframe(df)

df_rolling = df.rolling(window, center=center, min_periods=1).mean()
ds_rolling_strided = ds.rolling(
index=window, center=center, min_periods=1, stride=stride
)

if stride is None:
stride_index = 1
else:
stride_index = stride

# with construct
ds_rolling_mean = ds_rolling_strided.construct("window").mean("window")
np.testing.assert_allclose(
df_rolling["x"].values[::stride_index], ds_rolling_mean["x"].values
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_mean["index"]
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_mean["index"]
)

# with bottleneck
ds_rolling_strided_mean = ds_rolling_strided.mean()
np.testing.assert_allclose(
df_rolling["x"].values[::stride_index], ds_rolling_strided_mean["x"].values
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
)

# with fill_value
ds_rolling_mean = ds_rolling_strided.construct("window", fill_value=0.0).mean(
"window"
)
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0


@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
Expand Down