Skip to content

Commit 3c30805

Browse files
authored
[3.10] bpo-20499: Rounding error in statistics.pvariance (GH-28230) (GH-28248)
1 parent 6b996d6 commit 3c30805

File tree

3 files changed

+51
-52
lines changed

3 files changed

+51
-52
lines changed

Lib/statistics.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,17 @@ class StatisticsError(ValueError):
147147

148148
# === Private utilities ===
149149

150-
def _sum(data, start=0):
151-
"""_sum(data [, start]) -> (type, sum, count)
150+
def _sum(data):
151+
"""_sum(data) -> (type, sum, count)
152152
153153
Return a high-precision sum of the given numeric data as a fraction,
154154
together with the type to be converted to and the count of items.
155155
156-
If optional argument ``start`` is given, it is added to the total.
157-
If ``data`` is empty, ``start`` (defaulting to 0) is returned.
158-
159-
160156
Examples
161157
--------
162158
163-
>>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
164-
(<class 'float'>, Fraction(11, 1), 5)
159+
>>> _sum([3, 2.25, 4.5, -0.5, 0.25])
160+
(<class 'float'>, Fraction(19, 2), 5)
165161
166162
Some sources of round-off error will be avoided:
167163
@@ -184,10 +180,9 @@ def _sum(data, start=0):
184180
allowed.
185181
"""
186182
count = 0
187-
n, d = _exact_ratio(start)
188-
partials = {d: n}
183+
partials = {}
189184
partials_get = partials.get
190-
T = _coerce(int, type(start))
185+
T = int
191186
for typ, values in groupby(data, type):
192187
T = _coerce(T, typ) # or raise TypeError
193188
for n, d in map(_exact_ratio, values):
@@ -200,8 +195,7 @@ def _sum(data, start=0):
200195
assert not _isfinite(total)
201196
else:
202197
# Sum all the partial sums using builtin sum.
203-
# FIXME is this faster if we sum them in order of the denominator?
204-
total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
198+
total = sum(Fraction(n, d) for d, n in partials.items())
205199
return (T, total, count)
206200

207201

@@ -252,27 +246,19 @@ def _exact_ratio(x):
252246
x is expected to be an int, Fraction, Decimal or float.
253247
"""
254248
try:
255-
# Optimise the common case of floats. We expect that the most often
256-
# used numeric type will be builtin floats, so try to make this as
257-
# fast as possible.
258-
if type(x) is float or type(x) is Decimal:
259-
return x.as_integer_ratio()
260-
try:
261-
# x may be an int, Fraction, or Integral ABC.
262-
return (x.numerator, x.denominator)
263-
except AttributeError:
264-
try:
265-
# x may be a float or Decimal subclass.
266-
return x.as_integer_ratio()
267-
except AttributeError:
268-
# Just give up?
269-
pass
249+
return x.as_integer_ratio()
250+
except AttributeError:
251+
pass
270252
except (OverflowError, ValueError):
271253
# float NAN or INF.
272254
assert not _isfinite(x)
273255
return (x, None)
274-
msg = "can't convert type '{}' to numerator/denominator"
275-
raise TypeError(msg.format(type(x).__name__))
256+
try:
257+
# x may be an Integral ABC.
258+
return (x.numerator, x.denominator)
259+
except AttributeError:
260+
msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
261+
raise TypeError(msg)
276262

277263

278264
def _convert(value, T):
@@ -719,14 +705,20 @@ def _ss(data, c=None):
719705
if c is not None:
720706
T, total, count = _sum((x-c)**2 for x in data)
721707
return (T, total)
722-
c = mean(data)
723-
T, total, count = _sum((x-c)**2 for x in data)
724-
# The following sum should mathematically equal zero, but due to rounding
725-
# error may not.
726-
U, total2, count2 = _sum((x - c) for x in data)
727-
assert T == U and count == count2
728-
total -= total2 ** 2 / len(data)
729-
assert not total < 0, 'negative sum of square deviations: %f' % total
708+
T, total, count = _sum(data)
709+
mean_n, mean_d = (total / count).as_integer_ratio()
710+
partials = Counter()
711+
for n, d in map(_exact_ratio, data):
712+
diff_n = n * mean_d - d * mean_n
713+
diff_d = d * mean_d
714+
partials[diff_d * diff_d] += diff_n * diff_n
715+
if None in partials:
716+
# The sum will be a NAN or INF. We can ignore all the finite
717+
# partials, and just look at this special one.
718+
total = partials[None]
719+
assert not _isfinite(total)
720+
else:
721+
total = sum(Fraction(n, d) for d, n in partials.items())
730722
return (T, total)
731723

732724

@@ -830,6 +822,9 @@ def stdev(data, xbar=None):
830822
1.0810874155219827
831823
832824
"""
825+
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
826+
# remain because there are two rounding steps. The first occurs in
827+
# the _convert() step for variance(), the second occurs in math.sqrt().
833828
var = variance(data, xbar)
834829
try:
835830
return var.sqrt()
@@ -846,6 +841,9 @@ def pstdev(data, mu=None):
846841
0.986893273527251
847842
848843
"""
844+
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
845+
# remain because there are two rounding steps. The first occurs in
846+
# the _convert() step for pvariance(), the second occurs in math.sqrt().
849847
var = pvariance(data, mu)
850848
try:
851849
return var.sqrt()

Lib/test/test_statistics.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,20 +1247,14 @@ def test_empty_data(self):
12471247
# Override test for empty data.
12481248
for data in ([], (), iter([])):
12491249
self.assertEqual(self.func(data), (int, Fraction(0), 0))
1250-
self.assertEqual(self.func(data, 23), (int, Fraction(23), 0))
1251-
self.assertEqual(self.func(data, 2.3), (float, Fraction(2.3), 0))
12521250

12531251
def test_ints(self):
12541252
self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]),
12551253
(int, Fraction(60), 8))
1256-
self.assertEqual(self.func([4, 2, 3, -8, 7], 1000),
1257-
(int, Fraction(1008), 5))
12581254

12591255
def test_floats(self):
12601256
self.assertEqual(self.func([0.25]*20),
12611257
(float, Fraction(5.0), 20))
1262-
self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5),
1263-
(float, Fraction(3.125), 4))
12641258

12651259
def test_fractions(self):
12661260
self.assertEqual(self.func([Fraction(1, 1000)]*500),
@@ -1281,14 +1275,6 @@ def test_compare_with_math_fsum(self):
12811275
data = [random.uniform(-100, 1000) for _ in range(1000)]
12821276
self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16)
12831277

1284-
def test_start_argument(self):
1285-
# Test that the optional start argument works correctly.
1286-
data = [random.uniform(1, 1000) for _ in range(100)]
1287-
t = self.func(data)[1]
1288-
self.assertEqual(t+42, self.func(data, 42)[1])
1289-
self.assertEqual(t-23, self.func(data, -23)[1])
1290-
self.assertEqual(t+Fraction(1e20), self.func(data, 1e20)[1])
1291-
12921278
def test_strings_fail(self):
12931279
# Sum of strings should fail.
12941280
self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
@@ -2077,6 +2063,13 @@ def test_decimals(self):
20772063
self.assertEqual(result, exact)
20782064
self.assertIsInstance(result, Decimal)
20792065

2066+
def test_accuracy_bug_20499(self):
2067+
data = [0, 0, 1]
2068+
exact = 2 / 9
2069+
result = self.func(data)
2070+
self.assertEqual(result, exact)
2071+
self.assertIsInstance(result, float)
2072+
20802073

20812074
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
20822075
# Tests for sample variance.
@@ -2117,6 +2110,13 @@ def test_center_not_at_mean(self):
21172110
self.assertEqual(self.func(data), 0.5)
21182111
self.assertEqual(self.func(data, xbar=2.0), 1.0)
21192112

2113+
def test_accuracy_bug_20499(self):
2114+
data = [0, 0, 2]
2115+
exact = 4 / 3
2116+
result = self.func(data)
2117+
self.assertEqual(result, exact)
2118+
self.assertIsInstance(result, float)
2119+
21202120
class TestPStdev(VarianceStdevMixin, NumericTestCase):
21212121
# Tests for population standard deviation.
21222122
def setUp(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve the speed and accuracy of statistics.pvariance().

0 commit comments

Comments
 (0)