Skip to content

Commit 8e3c953

Browse files
vstinnermdickinson
andauthored
gh-73468: Add math.fma() function (#116667)
Added new math.fma() function, wrapping C99's ``fma()`` operation: fused multiply-add function. Co-authored-by: Mark Dickinson <[email protected]>
1 parent b8d808d commit 8e3c953

File tree

6 files changed

+371
-1
lines changed

6 files changed

+371
-1
lines changed

Doc/library/math.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,22 @@ Number-theoretic and representation functions
8282
should return an :class:`~numbers.Integral` value.
8383

8484

85+
.. function:: fma(x, y, z)
86+
87+
Fused multiply-add operation. Return ``(x * y) + z``, computed as though with
88+
infinite precision and range followed by a single round to the ``float``
89+
format. This operation often provides better accuracy than the direct
90+
expression ``(x * y) + z``.
91+
92+
This function follows the specification of the fusedMultiplyAdd operation
93+
described in the IEEE 754 standard. The standard leaves one case
94+
implementation-defined, namely the result of ``fma(0, inf, nan)``
95+
and ``fma(inf, 0, nan)``. In these cases, ``math.fma`` returns a NaN,
96+
and does not raise any exception.
97+
98+
.. versionadded:: 3.13
99+
100+
85101
.. function:: fmod(x, y)
86102

87103
Return ``fmod(x, y)``, as defined by the platform C library. Note that the

Doc/whatsnew/3.13.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,16 @@ marshal
383383
code objects which are incompatible between Python versions.
384384
(Contributed by Serhiy Storchaka in :gh:`113626`.)
385385

386+
math
387+
----
388+
389+
A new function :func:`~math.fma` for fused multiply-add operations has been
390+
added. This function computes ``x * y + z`` with only a single round, and so
391+
avoids any intermediate loss of precision. It wraps the ``fma()`` function
392+
provided by C99, and follows the specification of the IEEE 754
393+
"fusedMultiplyAdd" operation for special cases.
394+
(Contributed by Mark Dickinson and Victor Stinner in :gh:`73468`.)
395+
386396
mmap
387397
----
388398

Lib/test/test_math.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,244 @@ def test_fractions(self):
26132613
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
26142614

26152615

2616+
class FMATests(unittest.TestCase):
2617+
""" Tests for math.fma. """
2618+
2619+
def test_fma_nan_results(self):
2620+
# Selected representative values.
2621+
values = [
2622+
-math.inf, -1e300, -2.3, -1e-300, -0.0,
2623+
0.0, 1e-300, 2.3, 1e300, math.inf, math.nan
2624+
]
2625+
2626+
# If any input is a NaN, the result should be a NaN, too.
2627+
for a, b in itertools.product(values, repeat=2):
2628+
self.assertIsNaN(math.fma(math.nan, a, b))
2629+
self.assertIsNaN(math.fma(a, math.nan, b))
2630+
self.assertIsNaN(math.fma(a, b, math.nan))
2631+
2632+
def test_fma_infinities(self):
2633+
# Cases involving infinite inputs or results.
2634+
positives = [1e-300, 2.3, 1e300, math.inf]
2635+
finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300]
2636+
non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf]
2637+
2638+
# ValueError due to inf * 0 computation.
2639+
for c in non_nans:
2640+
for infinity in [math.inf, -math.inf]:
2641+
for zero in [0.0, -0.0]:
2642+
with self.assertRaises(ValueError):
2643+
math.fma(infinity, zero, c)
2644+
with self.assertRaises(ValueError):
2645+
math.fma(zero, infinity, c)
2646+
2647+
# ValueError when a*b and c both infinite of opposite signs.
2648+
for b in positives:
2649+
with self.assertRaises(ValueError):
2650+
math.fma(math.inf, b, -math.inf)
2651+
with self.assertRaises(ValueError):
2652+
math.fma(math.inf, -b, math.inf)
2653+
with self.assertRaises(ValueError):
2654+
math.fma(-math.inf, -b, -math.inf)
2655+
with self.assertRaises(ValueError):
2656+
math.fma(-math.inf, b, math.inf)
2657+
with self.assertRaises(ValueError):
2658+
math.fma(b, math.inf, -math.inf)
2659+
with self.assertRaises(ValueError):
2660+
math.fma(-b, math.inf, math.inf)
2661+
with self.assertRaises(ValueError):
2662+
math.fma(-b, -math.inf, -math.inf)
2663+
with self.assertRaises(ValueError):
2664+
math.fma(b, -math.inf, math.inf)
2665+
2666+
# Infinite result when a*b and c both infinite of the same sign.
2667+
for b in positives:
2668+
self.assertEqual(math.fma(math.inf, b, math.inf), math.inf)
2669+
self.assertEqual(math.fma(math.inf, -b, -math.inf), -math.inf)
2670+
self.assertEqual(math.fma(-math.inf, -b, math.inf), math.inf)
2671+
self.assertEqual(math.fma(-math.inf, b, -math.inf), -math.inf)
2672+
self.assertEqual(math.fma(b, math.inf, math.inf), math.inf)
2673+
self.assertEqual(math.fma(-b, math.inf, -math.inf), -math.inf)
2674+
self.assertEqual(math.fma(-b, -math.inf, math.inf), math.inf)
2675+
self.assertEqual(math.fma(b, -math.inf, -math.inf), -math.inf)
2676+
2677+
# Infinite result when a*b finite, c infinite.
2678+
for a, b in itertools.product(finites, finites):
2679+
self.assertEqual(math.fma(a, b, math.inf), math.inf)
2680+
self.assertEqual(math.fma(a, b, -math.inf), -math.inf)
2681+
2682+
# Infinite result when a*b infinite, c finite.
2683+
for b, c in itertools.product(positives, finites):
2684+
self.assertEqual(math.fma(math.inf, b, c), math.inf)
2685+
self.assertEqual(math.fma(-math.inf, b, c), -math.inf)
2686+
self.assertEqual(math.fma(-math.inf, -b, c), math.inf)
2687+
self.assertEqual(math.fma(math.inf, -b, c), -math.inf)
2688+
2689+
self.assertEqual(math.fma(b, math.inf, c), math.inf)
2690+
self.assertEqual(math.fma(b, -math.inf, c), -math.inf)
2691+
self.assertEqual(math.fma(-b, -math.inf, c), math.inf)
2692+
self.assertEqual(math.fma(-b, math.inf, c), -math.inf)
2693+
2694+
# gh-73468: On WASI and FreeBSD, libc fma() doesn't implement IEE 754-2008
2695+
# properly: it doesn't use the right sign when the result is zero.
2696+
@unittest.skipIf(support.is_wasi,
2697+
"WASI fma() doesn't implement IEE 754-2008 properly")
2698+
@unittest.skipIf(sys.platform.startswith('freebsd'),
2699+
"FreeBSD fma() doesn't implement IEE 754-2008 properly")
2700+
def test_fma_zero_result(self):
2701+
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
2702+
2703+
# Zero results from exact zero inputs.
2704+
for b in nonnegative_finites:
2705+
self.assertIsPositiveZero(math.fma(0.0, b, 0.0))
2706+
self.assertIsPositiveZero(math.fma(0.0, b, -0.0))
2707+
self.assertIsNegativeZero(math.fma(0.0, -b, -0.0))
2708+
self.assertIsPositiveZero(math.fma(0.0, -b, 0.0))
2709+
self.assertIsPositiveZero(math.fma(-0.0, -b, 0.0))
2710+
self.assertIsPositiveZero(math.fma(-0.0, -b, -0.0))
2711+
self.assertIsNegativeZero(math.fma(-0.0, b, -0.0))
2712+
self.assertIsPositiveZero(math.fma(-0.0, b, 0.0))
2713+
2714+
self.assertIsPositiveZero(math.fma(b, 0.0, 0.0))
2715+
self.assertIsPositiveZero(math.fma(b, 0.0, -0.0))
2716+
self.assertIsNegativeZero(math.fma(-b, 0.0, -0.0))
2717+
self.assertIsPositiveZero(math.fma(-b, 0.0, 0.0))
2718+
self.assertIsPositiveZero(math.fma(-b, -0.0, 0.0))
2719+
self.assertIsPositiveZero(math.fma(-b, -0.0, -0.0))
2720+
self.assertIsNegativeZero(math.fma(b, -0.0, -0.0))
2721+
self.assertIsPositiveZero(math.fma(b, -0.0, 0.0))
2722+
2723+
# Exact zero result from nonzero inputs.
2724+
self.assertIsPositiveZero(math.fma(2.0, 2.0, -4.0))
2725+
self.assertIsPositiveZero(math.fma(2.0, -2.0, 4.0))
2726+
self.assertIsPositiveZero(math.fma(-2.0, -2.0, -4.0))
2727+
self.assertIsPositiveZero(math.fma(-2.0, 2.0, 4.0))
2728+
2729+
# Underflow to zero.
2730+
tiny = 1e-300
2731+
self.assertIsPositiveZero(math.fma(tiny, tiny, 0.0))
2732+
self.assertIsNegativeZero(math.fma(tiny, -tiny, 0.0))
2733+
self.assertIsPositiveZero(math.fma(-tiny, -tiny, 0.0))
2734+
self.assertIsNegativeZero(math.fma(-tiny, tiny, 0.0))
2735+
self.assertIsPositiveZero(math.fma(tiny, tiny, -0.0))
2736+
self.assertIsNegativeZero(math.fma(tiny, -tiny, -0.0))
2737+
self.assertIsPositiveZero(math.fma(-tiny, -tiny, -0.0))
2738+
self.assertIsNegativeZero(math.fma(-tiny, tiny, -0.0))
2739+
2740+
# Corner case where rounding the multiplication would
2741+
# give the wrong result.
2742+
x = float.fromhex('0x1p-500')
2743+
y = float.fromhex('0x1p-550')
2744+
z = float.fromhex('0x1p-1000')
2745+
self.assertIsNegativeZero(math.fma(x-y, x+y, -z))
2746+
self.assertIsPositiveZero(math.fma(y-x, x+y, z))
2747+
self.assertIsNegativeZero(math.fma(y-x, -(x+y), -z))
2748+
self.assertIsPositiveZero(math.fma(x-y, -(x+y), z))
2749+
2750+
def test_fma_overflow(self):
2751+
a = b = float.fromhex('0x1p512')
2752+
c = float.fromhex('0x1p1023')
2753+
# Overflow from multiplication.
2754+
with self.assertRaises(OverflowError):
2755+
math.fma(a, b, 0.0)
2756+
self.assertEqual(math.fma(a, b/2.0, 0.0), c)
2757+
# Overflow from the addition.
2758+
with self.assertRaises(OverflowError):
2759+
math.fma(a, b/2.0, c)
2760+
# No overflow, even though a*b overflows a float.
2761+
self.assertEqual(math.fma(a, b, -c), c)
2762+
2763+
# Extreme case: a * b is exactly at the overflow boundary, so the
2764+
# tiniest offset makes a difference between overflow and a finite
2765+
# result.
2766+
a = float.fromhex('0x1.ffffffc000000p+511')
2767+
b = float.fromhex('0x1.0000002000000p+512')
2768+
c = float.fromhex('0x0.0000000000001p-1022')
2769+
with self.assertRaises(OverflowError):
2770+
math.fma(a, b, 0.0)
2771+
with self.assertRaises(OverflowError):
2772+
math.fma(a, b, c)
2773+
self.assertEqual(math.fma(a, b, -c),
2774+
float.fromhex('0x1.fffffffffffffp+1023'))
2775+
2776+
# Another extreme case: here a*b is about as large as possible subject
2777+
# to math.fma(a, b, c) being finite.
2778+
a = float.fromhex('0x1.ae565943785f9p+512')
2779+
b = float.fromhex('0x1.3094665de9db8p+512')
2780+
c = float.fromhex('0x1.fffffffffffffp+1023')
2781+
self.assertEqual(math.fma(a, b, -c), c)
2782+
2783+
def test_fma_single_round(self):
2784+
a = float.fromhex('0x1p-50')
2785+
self.assertEqual(math.fma(a - 1.0, a + 1.0, 1.0), a*a)
2786+
2787+
def test_random(self):
2788+
# A collection of randomly generated inputs for which the naive FMA
2789+
# (with two rounds) gives a different result from a singly-rounded FMA.
2790+
2791+
# tuples (a, b, c, expected)
2792+
test_values = [
2793+
('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1',
2794+
'0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'),
2795+
('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2',
2796+
'0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'),
2797+
('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1',
2798+
'0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'),
2799+
('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1',
2800+
'0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'),
2801+
('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1',
2802+
'0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'),
2803+
('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1',
2804+
'0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'),
2805+
('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2',
2806+
'0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'),
2807+
('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1',
2808+
'0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'),
2809+
('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1',
2810+
'0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'),
2811+
('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1',
2812+
'0x1.122b59317ebdfp-1', '0x1.0bf644b340cc5p+0'),
2813+
('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1',
2814+
'0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'),
2815+
('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1',
2816+
'0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'),
2817+
('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1',
2818+
'0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'),
2819+
('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1',
2820+
'0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'),
2821+
('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2',
2822+
'0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'),
2823+
('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2',
2824+
'0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'),
2825+
]
2826+
for a_hex, b_hex, c_hex, expected_hex in test_values:
2827+
a = float.fromhex(a_hex)
2828+
b = float.fromhex(b_hex)
2829+
c = float.fromhex(c_hex)
2830+
expected = float.fromhex(expected_hex)
2831+
self.assertEqual(math.fma(a, b, c), expected)
2832+
self.assertEqual(math.fma(b, a, c), expected)
2833+
2834+
# Custom assertions.
2835+
def assertIsNaN(self, value):
2836+
self.assertTrue(
2837+
math.isnan(value),
2838+
msg="Expected a NaN, got {!r}".format(value)
2839+
)
2840+
2841+
def assertIsPositiveZero(self, value):
2842+
self.assertTrue(
2843+
value == 0 and math.copysign(1, value) > 0,
2844+
msg="Expected a positive zero, got {!r}".format(value)
2845+
)
2846+
2847+
def assertIsNegativeZero(self, value):
2848+
self.assertTrue(
2849+
value == 0 and math.copysign(1, value) < 0,
2850+
msg="Expected a negative zero, got {!r}".format(value)
2851+
)
2852+
2853+
26162854
def load_tests(loader, tests, pattern):
26172855
from doctest import DocFileSuite
26182856
tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt")))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added new :func:`math.fma` function, wrapping C99's ``fma()`` operation:
2+
fused multiply-add function. Patch by Mark Dickinson and Victor Stinner.

Modules/clinic/mathmodule.c.h

Lines changed: 62 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)