Skip to content

Commit f154c6e

Browse files
committed
ENH: Use Welford's method in stats.moments.rolling_var
This PR implements a modified version of Welford's method to compute the rolling variance. Instead of keeping track of the sum and sum of the squares of the items in the window, it tracks the mean and the sum of squared differences from the mean. This turns out to be (much) more numerically stable. The formulas to update these two variables when adding or removing an item from the sequence are well known, see e.g. http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance The formulas used when both adding one and removing one item I have not seen explicitly worked out anywhere, but are not too hard to come up with if you put pen to (a lot of) paper.
1 parent 4a37102 commit f154c6e

File tree

4 files changed

+67
-30
lines changed

4 files changed

+67
-30
lines changed

doc/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ Improvements to existing features
276276
- Add option to turn off escaping in ``DataFrame.to_latex`` (:issue:`6472`)
277277
- Added ``how`` option to rolling-moment functions to dictate how to handle resampling; :func:``rolling_max`` defaults to max,
278278
:func:``rolling_min`` defaults to min, and all others default to mean (:issue:`6297`)
279+
- ``pd.stats.moments.rolling_var`` now uses Welford's method for increased numerical stability (:issue:`6817`)
279280

280281
.. _release.bug_fixes-0.14.0:
281282

pandas/algos.pyx

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,56 +1122,82 @@ def nancorr_spearman(ndarray[float64_t, ndim=2] mat, Py_ssize_t minp=1):
11221122
# Rolling variance
11231123

11241124
def roll_var(ndarray[double_t] input, int win, int minp, int ddof=1):
1125-
cdef double val, prev, sum_x = 0, sum_xx = 0, nobs = 0
1125+
"""
1126+
Numerically stable implementation using Welford's method.
1127+
"""
1128+
cdef double val, prev, mean_x = 0, ssqdm_x = 0, nobs = 0, delta
11261129
cdef Py_ssize_t i
11271130
cdef Py_ssize_t N = len(input)
11281131

11291132
cdef ndarray[double_t] output = np.empty(N, dtype=float)
11301133

11311134
minp = _check_minp(win, minp, N)
11321135

1133-
for i from 0 <= i < minp - 1:
1136+
for i from 0 <= i < win:
11341137
val = input[i]
11351138

11361139
# Not NaN
11371140
if val == val:
11381141
nobs += 1
1139-
sum_x += val
1140-
sum_xx += val * val
1142+
delta = (val - mean_x)
1143+
mean_x += delta / nobs
1144+
ssqdm_x += delta * (val - mean_x)
11411145

1142-
output[i] = NaN
1146+
if nobs >= minp:
1147+
#pathological case
1148+
if nobs == 1:
1149+
val = 0
1150+
else:
1151+
val = ssqdm_x / (nobs - ddof)
1152+
if val < 0:
1153+
val = 0
1154+
else:
1155+
val = NaN
11431156

1144-
for i from minp - 1 <= i < N:
1157+
output[i] = val
1158+
1159+
for i from win <= i < N:
11451160
val = input[i]
1161+
prev = input[i - win]
11461162

11471163
if val == val:
1148-
nobs += 1
1149-
sum_x += val
1150-
sum_xx += val * val
1151-
1152-
if i > win - 1:
1153-
prev = input[i - win]
11541164
if prev == prev:
1155-
sum_x -= prev
1156-
sum_xx -= prev * prev
1157-
nobs -= 1
1165+
delta = val - prev
1166+
prev -= mean_x
1167+
mean_x += delta / nobs
1168+
val -= mean_x
1169+
ssqdm_x += (val + prev) * delta
1170+
else:
1171+
nobs += 1
1172+
delta = (val - mean_x)
1173+
mean_x += delta / nobs
1174+
ssqdm_x += delta * (val - mean_x)
1175+
elif prev == prev:
1176+
nobs -= 1
1177+
if nobs:
1178+
delta = (prev - mean_x)
1179+
mean_x -= delta / nobs
1180+
ssqdm_x -= delta * (prev - mean_x)
1181+
else:
1182+
mean_x = 0
1183+
ssqdm_x = 0
11581184

11591185
if nobs >= minp:
1160-
# pathological case
1186+
#pathological case
11611187
if nobs == 1:
1162-
output[i] = 0
1163-
continue
1164-
1165-
val = (nobs * sum_xx - sum_x * sum_x) / (nobs * (nobs - ddof))
1166-
if val < 0:
11671188
val = 0
1168-
1169-
output[i] = val
1189+
else:
1190+
val = ssqdm_x / (nobs - ddof)
1191+
if val < 0:
1192+
val = 0
11701193
else:
1171-
output[i] = NaN
1194+
val = NaN
1195+
1196+
output[i] = val
11721197

11731198
return output
11741199

1200+
11751201
#-------------------------------------------------------------------------------
11761202
# Rolling skewness
11771203

pandas/stats/moments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,7 @@ def rolling_window(arg, window=None, win_type=None, min_periods=None,
751751
* ``gaussian`` (needs std)
752752
* ``general_gaussian`` (needs power, width)
753753
* ``slepian`` (needs width).
754-
754+
755755
By default, the result is set to the right edge of the window. This can be
756756
changed to the center of the window by setting ``center=True``.
757757
@@ -978,7 +978,7 @@ def expanding_apply(arg, func, min_periods=1, freq=None, center=False,
978978
Returns
979979
-------
980980
y : type of input argument
981-
981+
982982
Notes
983983
-----
984984
The `freq` keyword is used to conform time series data to a specified

pandas/stats/tests/test_moments.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ def test_rolling_std_neg_sqrt(self):
295295

296296
def test_rolling_var(self):
297297
self._check_moment_func(mom.rolling_var,
298-
lambda x: np.var(x, ddof=1))
298+
lambda x: np.var(x, ddof=1),
299+
test_stable=True)
299300
self._check_moment_func(functools.partial(mom.rolling_var, ddof=0),
300301
lambda x: np.var(x, ddof=0))
301302

@@ -349,13 +350,15 @@ def _check_moment_func(self, func, static_comp, window=50,
349350
has_center=True,
350351
has_time_rule=True,
351352
preserve_nan=True,
352-
fill_value=None):
353+
fill_value=None,
354+
test_stable=False):
353355

354356
self._check_ndarray(func, static_comp, window=window,
355357
has_min_periods=has_min_periods,
356358
preserve_nan=preserve_nan,
357359
has_center=has_center,
358-
fill_value=fill_value)
360+
fill_value=fill_value,
361+
test_stable=test_stable)
359362

360363
self._check_structures(func, static_comp,
361364
has_min_periods=has_min_periods,
@@ -367,7 +370,8 @@ def _check_ndarray(self, func, static_comp, window=50,
367370
has_min_periods=True,
368371
preserve_nan=True,
369372
has_center=True,
370-
fill_value=None):
373+
fill_value=None,
374+
test_stable=False):
371375

372376
result = func(self.arr, window)
373377
assert_almost_equal(result[-1],
@@ -425,6 +429,12 @@ def _check_ndarray(self, func, static_comp, window=50,
425429
self.assert_(np.isnan(expected[-5]))
426430
self.assert_(np.isnan(result[-14]))
427431

432+
if test_stable:
433+
result = func(self.arr + 1e9, window)
434+
assert_almost_equal(result[-1],
435+
static_comp(self.arr[-50:] + 1e9))
436+
437+
428438
def _check_structures(self, func, static_comp,
429439
has_min_periods=True, has_time_rule=True,
430440
has_center=True,

0 commit comments

Comments
 (0)