Skip to content

Commit 037b5fe

Browse files
committed
Correctly rounded stdev results for Decimal inputs
1 parent ba248b7 commit 037b5fe

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

Lib/statistics.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,18 @@ def _fail_neg(values, errmsg='negative value'):
305305
raise StatisticsError(errmsg)
306306
yield x
307307

308+
308309
def _isqrt_frac_rto(n: int, m: int) -> int:
    """Integer square root of n/m, rounded using round-to-odd.

    Takes math.isqrt() of the floor quotient n // m.  When that value is
    not the exact square root of n/m, its lowest bit is forced to 1, so
    every inexact result is odd.  The result is always an int.
    """
    # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
    a = math.isqrt(n // m)
    # The comparison is a bool (0 or 1); OR-ing it in sets the low bit
    # only when a*a != n/m, i.e. when the root is inexact.
    return a | (a*a*m != n)
313314

315+
314316
# Shift used by _sqrt_frac(): twice the float mantissa width, plus three.
# For floats with 53 bits of precision this evaluates to 109.
_sqrt_shift: int = 3 + 2 * sys.float_info.mant_dig
316318

319+
317320
def _sqrt_frac(n: int, m: int) -> float:
318321
"""Square root of n/m as a float, correctly rounded."""
319322
# See principle and proof sketch at: https://bugs.python.org/msg407078
@@ -327,6 +330,31 @@ def _sqrt_frac(n: int, m: int) -> float:
327330
return numerator / denominator # Convert to float
328331

329332

333+
def _deci_sqrt(n: int, m: int) -> Decimal:
    """Square root of n/m as a Decimal, correctly rounded.

    A candidate root is computed with Decimal arithmetic in the current
    context and then checked against the exact value Fraction(n, m).  The
    candidate is moved one step up or down when a neighbour is closer.

    A zero numerator returns a Decimal zero without further checks.
    Raises ZeroDivisionError when m is zero and n is not (raised by
    Fraction(n, m)).  A negative n/m is passed to Decimal.sqrt(), which
    signals decimal.InvalidOperation.
    """
    # Premise:  For decimal, computing sqrt(n / m) can be off by 1 ulp.
    # Method:   Check the result, moving up or down a step if needed.
    if not n:
        # Return a Decimal (not the float 0.0) to honour the return type.
        return Decimal('0.0')

    f_square = Fraction(n, m)

    d_mid = (Decimal(n) / Decimal(m)).sqrt()
    f_mid = Fraction(*d_mid.as_integer_ratio())

    # If the exact square lies above the square of the midpoint between
    # the candidate and its upper neighbour, the neighbour is the result.
    d_plus = d_mid.next_plus()
    f_plus = Fraction(*d_plus.as_integer_ratio())
    if f_square > ((f_mid + f_plus) / 2) ** 2:
        return d_plus

    # Symmetric check against the lower neighbour.
    d_minus = d_mid.next_minus()
    f_minus = Fraction(*d_minus.as_integer_ratio())
    if f_square < ((f_mid + f_minus) / 2) ** 2:
        return d_minus

    return d_mid
356+
357+
330358
# === Measures of central tendency (averages) ===
331359

332360
def mean(data):
@@ -888,9 +916,8 @@ def pstdev(data, mu=None):
888916
raise StatisticsError('pstdev requires at least one data point')
889917
T, ss = _ss(data, mu)
890918
mss = ss / n
891-
if hasattr(T, 'sqrt'):
892-
var = _convert(mss, T)
893-
return var.sqrt()
919+
if issubclass(T, Decimal):
920+
return _deci_sqrt(mss.numerator, mss.denominator)
894921
return _sqrt_frac(mss.numerator, mss.denominator)
895922

896923

Lib/test/test_statistics.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2219,7 +2219,42 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
22192219
statistics._sqrt_frac(1, 0)
22202220

22212221
# The result is well defined if both inputs are negative
2222-
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
2222+
self.assertEqual(statistics._sqrt_frac(-2, -1), statistics._sqrt_frac(2, 1))
2223+
2224+
def test_deci_sqrt(self):
    """Check statistics._deci_sqrt() on known roots, corner cases, and errors."""
    default_ctx = decimal.DefaultContext
    deci_sqrt = statistics._deci_sqrt
    denominator = 10 ** 30

    # Each case: (expected root as text, numerator over 10**30).
    cases = (
        ('0.4481904599041192673635338663', 200874688349065940678243576378),  # No adj
        ('0.7924949131383786609961759598', 628048187350206338833590574929),  # Adj up
        ('0.8500554152289934068192208727', 722594208960136395984391238251),  # Adj down
    )
    for root_text, numerator in cases:
        root = Decimal(root_text)

        with decimal.localcontext(default_ctx):
            self.assertEqual(deci_sqrt(numerator, denominator), root)

            # Confirm expected root with a quad precision decimal computation
            with decimal.localcontext(default_ctx) as wide_ctx:
                wide_ctx.prec *= 4
                quotient = Decimal(numerator) / Decimal(denominator)
                high_prec_root = quotient.sqrt()
            with decimal.localcontext(default_ctx):
                # Unary plus rounds the wide value to the default precision.
                target_root = +high_prec_root
                self.assertEqual(root, target_root)

    # Verify that corner cases and error handling match Decimal.sqrt()
    self.assertEqual(deci_sqrt(0, 1), 0.0)
    for bad_n, bad_m in ((-1, 1), (1, -1)):
        with self.assertRaises(decimal.InvalidOperation):
            deci_sqrt(bad_n, bad_m)

    # Error handling for zero denominator matches that for Fraction(1, 0)
    with self.assertRaises(ZeroDivisionError):
        deci_sqrt(1, 0)

    # The result is well defined if both inputs are negative
    self.assertEqual(deci_sqrt(-2, -1), deci_sqrt(2, 1))
22232258

22242259

22252260
class TestStdev(VarianceStdevMixin, NumericTestCase):

0 commit comments

Comments
 (0)