
Better rolling reductions #4915


Merged: 5 commits merged into pydata:master on Feb 19, 2021

Conversation

@dcherian (Contributor) commented Feb 16, 2021

Implements most of #4325 (comment)

%load_ext memory_profiler

import numpy as np
import xarray as xr

temp = xr.DataArray(np.zeros((5000, 500)), dims=("x", "y"))

roll = temp.rolling(x=10, y=20)

%memit roll.sum()
%memit roll.reduce(np.sum)
%memit roll.reduce(np.nansum)  # master branch behaviour
peak memory: 245.18 MiB, increment: 81.92 MiB
peak memory: 226.09 MiB, increment: 62.69 MiB
peak memory: 4493.82 MiB, increment: 4330.43 MiB
  • xref Optimize ndrolling nanreduce #4325
  • asv benchmarks added
  • Passes pre-commit run --all-files
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
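
For context, a minimal sketch of the trick (as I read #4325; not the literal code in this diff): fill NaNs with the reduction's identity element, then take a plain rolling reduction instead of materializing the construct view and calling np.nansum on it. For sum the identity is 0:

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(200), dims="x")
da = da.where(da > 0.2)  # sprinkle in some NaNs

# NaN-skipping sum via the memory-hungry construct path
slow = da.rolling(x=5, min_periods=1).construct("window").sum("window")
# identity-element trick: fill NaN with 0, then take a plain rolling sum
fast = da.fillna(0).rolling(x=5, min_periods=1).sum()
np.testing.assert_allclose(slow.values, fast.values)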

@@ -494,6 +527,14 @@ def _numpy_or_bottleneck_reduce(
                bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
            )
        else:
            if fillna is not None:
                if fillna is dtypes.INF:
                    fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True)
@dcherian (Contributor Author) commented on this diff:
Useless, since we always pad with NaN, which ends up promoting to float. We should add fill_value support to rolling.
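
A quick illustration (my own example, not from the PR): construct() pads partial windows with NaN, so integer input is promoted to float before any reduction sees it, which is why the integer branch of get_pos_infinity cannot currently trigger.

import numpy as np
import xarray as xr

ints = xr.DataArray(np.arange(6), dims="x")
print(ints.dtype)                                   # int64
print(ints.rolling(x=3).construct("window").dtype)  # float64: NaN padding promoted it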

@mathause (Collaborator) left a comment:

Clever, looks good.

@mathause (Collaborator) left a comment:

I had another look and this looks ready (unless you'd like to also do mean)

@dcherian (Contributor Author) commented:

I got mean to work. var is a little involved, so I haven't done that yet.

@fujiisoup (Member) left a comment:

Very clever implementation @dcherian!

> I got mean to work.

Nice ;)

> var is a little involved, so I haven't done that yet.

I think we can just leave the difficult ones with a TODO comment.

@dcherian (Contributor Author) commented:

Thanks for the reviews @mathause and @fujiisoup.

@dcherian merged commit 9a4313b into pydata:master Feb 19, 2021
@dcherian deleted the rolling-reductions branch February 19, 2021 19:44
@tbloch1 commented Apr 12, 2023

Has there been any progress on this for var/std?

@dcherian (Contributor Author) commented:

We would welcome a PR. Looking at the implementation of mean should help:

def _mean(self, keep_attrs, **kwargs):

@tbloch1 commented Apr 13, 2023

I think I may have found a way to make it more memory efficient, but I don't know enough about writing the sort of code that would be needed for a PR.

I basically wrote out the calculation for variance using only the functions that have already been optimised. Derived from:

$$ var = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 $$

$$ var = \frac{1}{n} \left( (x_1 - \mu)^2 + (x_2 - \mu)^2 + (x_3 - \mu)^2 + \dots \right) $$

$$ var = \frac{1}{n} \left( x_1^2 - 2x_1\mu + \mu^2 + x_2^2 - 2x_2\mu + \mu^2 + x_3^2 - 2x_3\mu + \mu^2 + \dots \right) $$

$$ var = \frac{1}{n} \left( \sum_{i=1}^{n} x_i^2 - 2\mu\sum_{i=1}^{n} x_i + n\mu^2 \right)$$

I coded this up and it uses approximately 10% of the peak memory of the current .var() implementation:

%load_ext memory_profiler

import numpy as np
import xarray as xr

temp = xr.DataArray(np.random.randint(0, 10, (5000, 500)), dims=("x", "y"))

def new_var(da, x=10, y=20):
    # Defining the re-used parts
    roll = da.rolling(x=x, y=y)
    mean = roll.mean()
    count = roll.count()
    # First term: sum of squared values
    term1 = (da**2).rolling(x=x, y=y).sum()
    # Second term: cross-term sum
    term2 = -2 * mean * roll.sum()
    # Third term: 'sum' of squared means
    term3 = count * mean**2
    # Combining into the variance
    var = (term1 + term2 + term3) / count
    return var

def old_var(da, x=10, y=20):
    roll = da.rolling(x=x, y=y)
    var = roll.var()
    return var

%memit new_var(temp)
%memit old_var(temp)
peak memory: 429.77 MiB, increment: 134.92 MiB
peak memory: 5064.07 MiB, increment: 4768.45 MiB

I wanted to double check that the calculation was working correctly:

var_n = new_var(temp)
var_o = old_var(temp)
print((var_o.where(~np.isnan(var_o), 0) == var_n.where(~np.isnan(var_n), 0)).all().values)
print(np.allclose(var_o, var_n, equal_nan=True))
False
True

I think the difference here is just due to floating point errors, but maybe someone who knows how to check that in more detail could have a look.
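
One quick way to put a number on it (sum-of-squares variance formulas are known to lose precision to cancellation, so a small relative error is expected rather than alarming):

diff = np.abs(var_o - var_n)
print(float(diff.max()))                                      # worst absolute difference
print(float((diff / np.abs(var_o).where(var_o != 0)).max()))  # worst relative difference, ignoring exact zeros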

The standard deviation can be trivially implemented from this if the approach works.
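
A sketch of that, reusing new_var from above:

def new_std(da, x=10, y=20):
    # rolling standard deviation as the square root of the rolling variance
    return np.sqrt(new_var(da, x=x, y=y))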

@dcherian (Contributor Author) commented:

Can you copy your comment to #4325 please?
