Skip to content

Commit 055760e

Browse files
JelleZijlstraAlexWaygoodgvanrossumFidget-Spinner
authored
gh-89263: Add typing.get_overloads (GH-31716)
Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood. Co-authored-by: Alex Waygood <[email protected]> Co-authored-by: Guido van Rossum <[email protected]> Co-authored-by: Ken Jin <[email protected]>
1 parent 9300b6d commit 055760e

File tree

4 files changed

+133
-4
lines changed

4 files changed

+133
-4
lines changed

Doc/library/typing.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,6 +2407,35 @@ Functions and decorators
24072407

24082408
See :pep:`484` for details and comparison with other typing semantics.
24092409

2410+
.. versionchanged:: 3.11
2411+
Overloaded functions can now be introspected at runtime using
2412+
:func:`get_overloads`.
2413+
2414+
2415+
.. function:: get_overloads(func)
2416+
2417+
Return a sequence of :func:`@overload <overload>`-decorated definitions for
2418+
*func*. *func* is the function object for the implementation of the
2419+
overloaded function. For example, given the definition of ``process`` in
2420+
the documentation for :func:`@overload <overload>`,
2421+
``get_overloads(process)`` will return a sequence of three function objects
2422+
for the three defined overloads. If called on a function with no overloads,
2423+
``get_overloads`` returns an empty sequence.
2424+
2425+
``get_overloads`` can be used for introspecting an overloaded function at
2426+
runtime.
2427+
2428+
.. versionadded:: 3.11
2429+
2430+
2431+
.. function:: clear_overloads()
2432+
2433+
Clear all registered overloads in the internal registry. This can be used
2434+
to reclaim the memory used by the registry.
2435+
2436+
.. versionadded:: 3.11
2437+
2438+
24102439
.. decorator:: final
24112440

24122441
A decorator to indicate to type checkers that the decorated method

Lib/test/test_typing.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import contextlib
22
import collections
3+
from collections import defaultdict
34
from functools import lru_cache
45
import inspect
56
import pickle
67
import re
78
import sys
89
import warnings
910
from unittest import TestCase, main, skipUnless, skip
11+
from unittest.mock import patch
1012
from copy import copy, deepcopy
1113

1214
from typing import Any, NoReturn, Never, assert_never
15+
from typing import overload, get_overloads, clear_overloads
1316
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
1417
from typing import T, KT, VT # Not in __all__.
1518
from typing import Union, Optional, Literal
@@ -3890,11 +3893,22 @@ def test_or(self):
38903893
self.assertEqual("x" | X, Union["x", X])
38913894

38923895

3896+
@lru_cache()
3897+
def cached_func(x, y):
3898+
return 3 * x + y
3899+
3900+
3901+
class MethodHolder:
3902+
@classmethod
3903+
def clsmethod(cls): ...
3904+
@staticmethod
3905+
def stmethod(): ...
3906+
def method(self): ...
3907+
3908+
38933909
class OverloadTests(BaseTestCase):
38943910

38953911
def test_overload_fails(self):
3896-
from typing import overload
3897-
38983912
with self.assertRaises(RuntimeError):
38993913

39003914
@overload
@@ -3904,8 +3918,6 @@ def blah():
39043918
blah()
39053919

39063920
def test_overload_succeeds(self):
3907-
from typing import overload
3908-
39093921
@overload
39103922
def blah():
39113923
pass
@@ -3915,6 +3927,58 @@ def blah():
39153927

39163928
blah()
39173929

3930+
def set_up_overloads(self):
3931+
def blah():
3932+
pass
3933+
3934+
overload1 = blah
3935+
overload(blah)
3936+
3937+
def blah():
3938+
pass
3939+
3940+
overload2 = blah
3941+
overload(blah)
3942+
3943+
def blah():
3944+
pass
3945+
3946+
return blah, [overload1, overload2]
3947+
3948+
# Make sure we don't clear the global overload registry
3949+
@patch("typing._overload_registry",
3950+
defaultdict(lambda: defaultdict(dict)))
3951+
def test_overload_registry(self):
3952+
# The registry starts out empty
3953+
self.assertEqual(typing._overload_registry, {})
3954+
3955+
impl, overloads = self.set_up_overloads()
3956+
self.assertNotEqual(typing._overload_registry, {})
3957+
self.assertEqual(list(get_overloads(impl)), overloads)
3958+
3959+
def some_other_func(): pass
3960+
overload(some_other_func)
3961+
other_overload = some_other_func
3962+
def some_other_func(): pass
3963+
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])
3964+
3965+
# Make sure that after we clear all overloads, the registry is
3966+
# completely empty.
3967+
clear_overloads()
3968+
self.assertEqual(typing._overload_registry, {})
3969+
self.assertEqual(get_overloads(impl), [])
3970+
3971+
# Querying a function with no overloads shouldn't change the registry.
3972+
def the_only_one(): pass
3973+
self.assertEqual(get_overloads(the_only_one), [])
3974+
self.assertEqual(typing._overload_registry, {})
3975+
3976+
def test_overload_registry_repeated(self):
3977+
for _ in range(2):
3978+
impl, overloads = self.set_up_overloads()
3979+
3980+
self.assertEqual(list(get_overloads(impl)), overloads)
3981+
39183982

39193983
# Definitions needed for features introduced in Python 3.6
39203984

Lib/typing.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from abc import abstractmethod, ABCMeta
2323
import collections
24+
from collections import defaultdict
2425
import collections.abc
2526
import contextlib
2627
import functools
@@ -121,9 +122,11 @@ def _idfunc(_, x):
121122
'assert_type',
122123
'assert_never',
123124
'cast',
125+
'clear_overloads',
124126
'final',
125127
'get_args',
126128
'get_origin',
129+
'get_overloads',
127130
'get_type_hints',
128131
'is_typeddict',
129132
'LiteralString',
@@ -2450,6 +2453,10 @@ def _overload_dummy(*args, **kwds):
24502453
"by an implementation that is not @overload-ed.")
24512454

24522455

2456+
# {module: {qualname: {firstlineno: func}}}
2457+
_overload_registry = defaultdict(functools.partial(defaultdict, dict))
2458+
2459+
24532460
def overload(func):
24542461
"""Decorator for overloaded functions/methods.
24552462
@@ -2475,10 +2482,37 @@ def utf8(value: bytes) -> bytes: ...
24752482
def utf8(value: str) -> bytes: ...
24762483
def utf8(value):
24772484
# implementation goes here
2485+
2486+
The overloads for a function can be retrieved at runtime using the
2487+
get_overloads() function.
24782488
"""
2489+
# classmethod and staticmethod
2490+
f = getattr(func, "__func__", func)
2491+
try:
2492+
_overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
2493+
except AttributeError:
2494+
# Not a normal function; ignore.
2495+
pass
24792496
return _overload_dummy
24802497

24812498

2499+
def get_overloads(func):
2500+
"""Return all defined overloads for *func* as a sequence."""
2501+
# classmethod and staticmethod
2502+
f = getattr(func, "__func__", func)
2503+
if f.__module__ not in _overload_registry:
2504+
return []
2505+
mod_dict = _overload_registry[f.__module__]
2506+
if f.__qualname__ not in mod_dict:
2507+
return []
2508+
return list(mod_dict[f.__qualname__].values())
2509+
2510+
2511+
def clear_overloads():
2512+
"""Clear all overloads in the registry."""
2513+
_overload_registry.clear()
2514+
2515+
24822516
def final(f):
24832517
"""A decorator to indicate final methods and final classes.
24842518
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add :func:`typing.get_overloads` and :func:`typing.clear_overloads`.
2+
Patch by Jelle Zijlstra.

0 commit comments

Comments
 (0)