Skip to content

Use numpy & dask sliding_window_view for rolling #4977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions doc/user-guide/duckarrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ the code will still cast to ``numpy`` arrays:
:py:meth:`DataArray.interp` and :py:meth:`DataArray.interp_like` (uses ``scipy``):
duck arrays in data variables and non-dimension coordinates will be casted in
addition to not supporting duck arrays in dimension coordinates
* :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (requires ``numpy>=1.20``)
* :py:meth:`Dataset.rolling_exp` and :py:meth:`DataArray.rolling_exp` (uses
``numbagg``)
* :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (uses internal functions
of ``numpy``)
* :py:meth:`Dataset.interpolate_na` and :py:meth:`DataArray.interpolate_na` (uses
:py:class:`numpy.vectorize`)
* :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`)
Expand Down
129 changes: 129 additions & 0 deletions xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,132 @@ def nanmedian(a, axis=None, keepdims=False):
)

return result


if LooseVersion(dask_version) > LooseVersion("2.30.0"):
ensure_minimum_chunksize = da.overlap.ensure_minimum_chunksize
else:

# copied from dask
def ensure_minimum_chunksize(size, chunks):
"""Determine new chunks to ensure that every chunk >= size

Parameters
----------
size: int
The maximum size of any chunk.
chunks: tuple
Chunks along one axis, e.g. ``(3, 3, 2)``

Examples
--------
>>> ensure_minimum_chunksize(10, (20, 20, 1))
(20, 11, 10)
>>> ensure_minimum_chunksize(3, (1, 1, 3))
(5,)

See Also
--------
overlap
"""
if size <= min(chunks):
return chunks

# add too-small chunks to chunks before them
output = []
new = 0
for c in chunks:
if c < size:
if new > size + (size - c):
output.append(new - (size - c))
new = size
else:
new += c
if new >= size:
output.append(new)
new = 0
if c >= size:
new += c
if new >= size:
output.append(new)
elif len(output) >= 1:
output[-1] += new
else:
raise ValueError(
f"The overlapping depth {size} is larger than your "
f"array {sum(chunks)}."
)

return tuple(output)


if LooseVersion(dask_version) > LooseVersion("2021.03.0"):
sliding_window_view = da.lib.stride_tricks.sliding_window_view
else:

def sliding_window_view(x, window_shape, axis=None):
from dask.array.overlap import map_overlap
from numpy.core.numeric import normalize_axis_tuple # type: ignore

from .npcompat import sliding_window_view as _np_sliding_window_view

window_shape = (
tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
)

window_shape_array = np.array(window_shape)
if np.any(window_shape_array <= 0):
raise ValueError("`window_shape` must contain positive values")

if axis is None:
axis = tuple(range(x.ndim))
if len(window_shape) != len(axis):
raise ValueError(
f"Since axis is `None`, must provide "
f"window_shape for all dimensions of `x`; "
f"got {len(window_shape)} window_shape elements "
f"and `x.ndim` is {x.ndim}."
)
else:
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
if len(window_shape) != len(axis):
raise ValueError(
f"Must provide matching length window_shape and "
f"axis; got {len(window_shape)} window_shape "
f"elements and {len(axis)} axes elements."
)

depths = [0] * x.ndim
for ax, window in zip(axis, window_shape):
depths[ax] += window - 1

# Ensure that each chunk is big enough to leave at least a size-1 chunk
# after windowing (this is only really necessary for the last chunk).
safe_chunks = tuple(
ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks)
)
x = x.rechunk(safe_chunks)

# result.shape = x_shape_trimmed + window_shape,
# where x_shape_trimmed is x.shape with every entry
# reduced by one less than the corresponding window size.
# trim chunks to match x_shape_trimmed
newchunks = tuple(
c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks)
) + tuple((window,) for window in window_shape)

kwargs = dict(
depth=tuple((0, d) for d in depths), # Overlap on +ve side only
boundary="none",
meta=x._meta,
new_axis=range(x.ndim, x.ndim + len(axis)),
chunks=newchunks,
trim=False,
window_shape=window_shape,
axis=axis,
)
# map_overlap's signature changed in https://github.com/dask/dask/pull/6165
if LooseVersion(dask_version) > "2.18.0":
return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs)
else:
return map_overlap(x, _np_sliding_window_view, **kwargs)
88 changes: 0 additions & 88 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import numpy as np

from . import dtypes, nputils


Expand All @@ -26,92 +24,6 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
return result


def rolling_window(a, axis, window, center, fill_value):
"""Dask's equivalence to np.utils.rolling_window"""
import dask.array as da

if not hasattr(axis, "__len__"):
axis = [axis]
window = [window]
center = [center]

orig_shape = a.shape
depth = {d: 0 for d in range(a.ndim)}
offset = [0] * a.ndim
drop_size = [0] * a.ndim
pad_size = [0] * a.ndim
for ax, win, cent in zip(axis, window, center):
if ax < 0:
ax = a.ndim + ax
depth[ax] = int(win / 2)
# For evenly sized window, we need to crop the first point of each block.
offset[ax] = 1 if win % 2 == 0 else 0

if depth[ax] > min(a.chunks[ax]):
raise ValueError(
"For window size %d, every chunk should be larger than %d, "
"but the smallest chunk size is %d. Rechunk your array\n"
"with a larger chunk size or a chunk size that\n"
"more evenly divides the shape of your array."
% (win, depth[ax], min(a.chunks[ax]))
)

# Although da.overlap pads values to boundaries of the array,
# the size of the generated array is smaller than what we want
# if center == False.
if cent:
start = int(win / 2) # 10 -> 5, 9 -> 4
end = win - 1 - start
else:
start, end = win - 1, 0
pad_size[ax] = max(start, end) + offset[ax] - depth[ax]
drop_size[ax] = 0
# pad_size becomes more than 0 when the overlapped array is smaller than
# needed. In this case, we need to enlarge the original array by padding
# before overlapping.
if pad_size[ax] > 0:
if pad_size[ax] < depth[ax]:
# overlapping requires each chunk larger than depth. If pad_size is
# smaller than the depth, we enlarge this and truncate it later.
drop_size[ax] = depth[ax] - pad_size[ax]
pad_size[ax] = depth[ax]

# TODO maybe following two lines can be summarized.
a = da.pad(
a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value
)
boundary = {d: fill_value for d in range(a.ndim)}

# create overlap arrays
ag = da.overlap.overlap(a, depth=depth, boundary=boundary)

def func(x, window, axis):
x = np.asarray(x)
index = [slice(None)] * x.ndim
for ax, win in zip(axis, window):
x = nputils._rolling_window(x, win, ax)
index[ax] = slice(offset[ax], None)
return x[tuple(index)]

chunks = list(a.chunks) + window
new_axis = [a.ndim + i for i in range(len(axis))]
out = da.map_blocks(
func,
ag,
dtype=a.dtype,
new_axis=new_axis,
chunks=chunks,
window=window,
axis=axis,
)

# crop boundary.
index = [slice(None)] * a.ndim
for ax in axis:
index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax])
return out[tuple(index)]


def least_squares(lhs, rhs, rcond=None, skipna=False):
import dask.array as da

Expand Down
8 changes: 4 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,15 @@ def last(values, axis, skipna=None):
return take(values, -1, axis=axis)


def rolling_window(array, axis, window, center, fill_value):
def sliding_window_view(array, window_shape, axis):
"""
Make an ndarray with a rolling window of axis-th dimension.
The rolling dimension will be placed at the last dimension.
"""
if is_duck_dask_array(array):
return dask_array_ops.rolling_window(array, axis, window, center, fill_value)
else: # np.ndarray
return nputils.rolling_window(array, axis, window, center, fill_value)
return dask_array_compat.sliding_window_view(array, window_shape, axis)
else:
return npcompat.sliding_window_view(array, window_shape, axis)


def least_squares(lhs, rhs, rcond=None, skipna=False):
Expand Down
97 changes: 97 additions & 0 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import builtins
import operator
from distutils.version import LooseVersion
from typing import Union

import numpy as np
Expand Down Expand Up @@ -96,3 +97,99 @@ def __array_function__(self, *args, **kwargs):


IS_NEP18_ACTIVE = _is_nep18_active()


if LooseVersion(np.__version__) >= "1.20.0":
sliding_window_view = np.lib.stride_tricks.sliding_window_view
else:
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from numpy.lib.stride_tricks import as_strided

# copied from numpy.lib.stride_tricks
def sliding_window_view(
x, window_shape, axis=None, *, subok=False, writeable=False
):
"""
Create a sliding window view into the array with the given window shape.

Also known as rolling or moving window, the window slides across all
dimensions of the array and extracts subsets of the array at all window
positions.

.. versionadded:: 1.20.0

Parameters
----------
x : array_like
Array to create the sliding window view from.
window_shape : int or tuple of int
Size of window over each axis that takes part in the sliding window.
If `axis` is not present, must have same length as the number of input
array dimensions. Single integers `i` are treated as if they were the
tuple `(i,)`.
axis : int or tuple of int, optional
Axis or axes along which the sliding window is applied.
By default, the sliding window is applied to all axes and
`window_shape[i]` will refer to axis `i` of `x`.
If `axis` is given as a `tuple of int`, `window_shape[i]` will refer to
the axis `axis[i]` of `x`.
Single integers `i` are treated as if they were the tuple `(i,)`.
subok : bool, optional
If True, sub-classes will be passed-through, otherwise the returned
array will be forced to be a base-class array (default).
writeable : bool, optional
When true, allow writing to the returned view. The default is false,
as this should be used with caution: the returned view contains the
same memory location multiple times, so writing to one location will
cause others to change.

Returns
-------
view : ndarray
Sliding window view of the array. The sliding window dimensions are
inserted at the end, and the original dimensions are trimmed as
required by the size of the sliding window.
That is, ``view.shape = x_shape_trimmed + window_shape``, where
``x_shape_trimmed`` is ``x.shape`` with every entry reduced by one less
than the corresponding window size.
"""
window_shape = (
tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
)
# first convert input to array, possibly keeping subclass
x = np.array(x, copy=False, subok=subok)

window_shape_array = np.array(window_shape)
if np.any(window_shape_array < 0):
raise ValueError("`window_shape` cannot contain negative values")

if axis is None:
axis = tuple(range(x.ndim))
if len(window_shape) != len(axis):
raise ValueError(
f"Since axis is `None`, must provide "
f"window_shape for all dimensions of `x`; "
f"got {len(window_shape)} window_shape elements "
f"and `x.ndim` is {x.ndim}."
)
else:
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
if len(window_shape) != len(axis):
raise ValueError(
f"Must provide matching length window_shape and "
f"axis; got {len(window_shape)} window_shape "
f"elements and {len(axis)} axes elements."
)

out_strides = x.strides + tuple(x.strides[ax] for ax in axis)

# note: same axis can be windowed repeatedly
x_shape_trimmed = list(x.shape)
for ax, dim in zip(axis, window_shape):
if x_shape_trimmed[ax] < dim:
raise ValueError("window shape cannot be larger than input array shape")
x_shape_trimmed[ax] -= dim - 1
out_shape = tuple(x_shape_trimmed) + window_shape
return as_strided(
x, strides=out_strides, shape=out_shape, subok=subok, writeable=writeable
)
Loading