Skip to content

Commit 793f55b

Browse files
authored
bpo-39218: Improve accuracy of variance calculation (GH-27960)
1 parent 044e8d8 commit 793f55b

File tree

3 files changed

+23
-14
lines changed

3 files changed

+23
-14
lines changed

Lib/statistics.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -728,15 +728,19 @@ def _ss(data, c=None):
728728
lead to garbage results.
729729
"""
730730
if c is not None:
731-
T, total, count = _sum((x-c)**2 for x in data)
731+
T, total, count = _sum((d := x - c) * d for x in data)
732732
return (T, total)
733+
# Compute the mean accurate to within 1/2 ulp
733734
c = mean(data)
734-
T, total, count = _sum((x-c)**2 for x in data)
735-
# The following sum should mathematically equal zero, but due to rounding
736-
# error may not.
737-
U, total2, count2 = _sum((x - c) for x in data)
738-
assert T == U and count == count2
739-
total -= total2 ** 2 / len(data)
735+
# Initial computation for the sum of square deviations
736+
T, total, count = _sum((d := x - c) * d for x in data)
737+
# Correct any remaining inaccuracy in the mean c.
738+
# The following sum should mathematically equal zero,
739+
# but due to the final rounding of the mean, it may not.
740+
U, error, count2 = _sum((x - c) for x in data)
741+
assert count == count2
742+
correction = error * error / len(data)
743+
total -= correction
740744
assert not total < 0, 'negative sum of square deviations: %f' % total
741745
return (T, total)
742746

@@ -924,8 +928,8 @@ def correlation(x, y, /):
924928
xbar = fsum(x) / n
925929
ybar = fsum(y) / n
926930
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
927-
sxx = fsum((xi - xbar) ** 2.0 for xi in x)
928-
syy = fsum((yi - ybar) ** 2.0 for yi in y)
931+
sxx = fsum((d := xi - xbar) * d for xi in x)
932+
syy = fsum((d := yi - ybar) * d for yi in y)
929933
try:
930934
return sxy / sqrt(sxx * syy)
931935
except ZeroDivisionError:
@@ -968,7 +972,7 @@ def linear_regression(x, y, /):
968972
xbar = fsum(x) / n
969973
ybar = fsum(y) / n
970974
sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
971-
sxx = fsum((xi - xbar) ** 2.0 for xi in x)
975+
sxx = fsum((d := xi - xbar) * d for xi in x)
972976
try:
973977
slope = sxy / sxx # equivalent to: covariance(x, y) / variance(x)
974978
except ZeroDivisionError:
@@ -1094,10 +1098,11 @@ def samples(self, n, *, seed=None):
10941098

10951099
def pdf(self, x):
10961100
"Probability density function. P(x <= X < x+dx) / dx"
1097-
variance = self._sigma ** 2.0
1101+
variance = self._sigma * self._sigma
10981102
if not variance:
10991103
raise StatisticsError('pdf() not defined when sigma is zero')
1100-
return exp((x - self._mu)**2.0 / (-2.0*variance)) / sqrt(tau*variance)
1104+
diff = x - self._mu
1105+
return exp(diff * diff / (-2.0 * variance)) / sqrt(tau * variance)
11011106

11021107
def cdf(self, x):
11031108
"Cumulative distribution function. P(X <= x)"
@@ -1161,7 +1166,7 @@ def overlap(self, other):
11611166
if not dv:
11621167
return 1.0 - erf(dm / (2.0 * X._sigma * sqrt(2.0)))
11631168
a = X._mu * Y_var - Y._mu * X_var
1164-
b = X._sigma * Y._sigma * sqrt(dm**2.0 + dv * log(Y_var / X_var))
1169+
b = X._sigma * Y._sigma * sqrt(dm * dm + dv * log(Y_var / X_var))
11651170
x1 = (a + b) / dv
11661171
x2 = (a - b) / dv
11671172
return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))
@@ -1204,7 +1209,7 @@ def stdev(self):
12041209
@property
12051210
def variance(self):
12061211
"Square of the standard deviation."
1207-
return self._sigma ** 2.0
1212+
return self._sigma * self._sigma
12081213

12091214
def __add__(x1, x2):
12101215
"""Add a constant or another NormalDist instance.

Lib/test/test_statistics.py

+3
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,9 @@ def __pow__(self, other):
12101210
def __add__(self, other):
12111211
return type(self)(super().__add__(other))
12121212
__radd__ = __add__
1213+
def __mul__(self, other):
1214+
return type(self)(super().__mul__(other))
1215+
__rmul__ = __mul__
12131216
return (float, Decimal, Fraction, MyFloat)
12141217

12151218
def test_types_conserved(self):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve accuracy of variance calculations by using x*x instead of x**2.

0 commit comments

Comments
 (0)