Skip to content

bpo-40396: Support GenericAlias in the typing functions. #19718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import NamedTuple, TypedDict
from typing import IO, TextIO, BinaryIO
from typing import Pattern, Match
from typing import Annotated
from typing import Annotated, ForwardRef
import abc
import typing
import weakref
Expand Down Expand Up @@ -1756,11 +1756,17 @@ def test_extended_generic_rules_repr(self):

def test_generic_forward_ref(self):
def foobar(x: List[List['CC']]): ...
def foobar2(x: list[list[ForwardRef('CC')]]): ...
class CC: ...
self.assertEqual(
get_type_hints(foobar, globals(), locals()),
{'x': List[List[CC]]}
)
self.assertEqual(
get_type_hints(foobar2, globals(), locals()),
{'x': list[list[CC]]}
)

T = TypeVar('T')
AT = Tuple[T, ...]
def barfoo(x: AT): ...
Expand Down Expand Up @@ -2446,6 +2452,12 @@ def foo(a: Tuple['T']):
self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': Tuple[T]})

def foo(a: tuple[ForwardRef('T')]):
pass

self.assertEqual(get_type_hints(foo, globals(), locals()),
{'a': tuple[T]})

def test_forward_recursion_actually(self):
def namespace1():
a = typing.ForwardRef('A')
Expand Down Expand Up @@ -2909,19 +2921,41 @@ def foobar(x: List['X']): ...
get_type_hints(foobar, globals(), locals(), include_extras=True),
{'x': List[Annotated[int, (1, 10)]]}
)

def foobar(x: list[ForwardRef('X')]): ...
X = Annotated[int, (1, 10)]
self.assertEqual(
get_type_hints(foobar, globals(), locals()),
{'x': list[int]}
)
self.assertEqual(
get_type_hints(foobar, globals(), locals(), include_extras=True),
{'x': list[Annotated[int, (1, 10)]]}
)

BA = Tuple[Annotated[T, (1, 0)], ...]
def barfoo(x: BA): ...
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...])
self.assertIs(
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
BA
)

BA = tuple[Annotated[T, (1, 0)], ...]
def barfoo(x: BA): ...
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], tuple[T, ...])
self.assertIs(
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
BA
)

def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]],
y: typing.Union[int, Annotated[T, "mutable"]]): ...
self.assertEqual(
get_type_hints(barfoo2, globals(), locals()),
{'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]}
)

BA2 = typing.Callable[..., List[T]]
def barfoo3(x: BA2): ...
self.assertIs(
Expand Down Expand Up @@ -2972,6 +3006,9 @@ class C(Generic[T]): pass
self.assertIs(get_origin(Generic[T]), Generic)
self.assertIs(get_origin(List[Tuple[T, T]][int]), list)
self.assertIs(get_origin(Annotated[T, 'thing']), Annotated)
self.assertIs(get_origin(List), list)
self.assertIs(get_origin(list[int]), list)
self.assertIs(get_origin(list), None)

def test_get_args(self):
T = TypeVar('T')
Expand All @@ -2993,6 +3030,9 @@ class C(Generic[T]): pass
self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
self.assertEqual(get_args(Tuple[()]), ((),))
self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three']))
self.assertEqual(get_args(List), (typing.T,))
self.assertEqual(get_args(list[int]), (int,))
self.assertEqual(get_args(list), ())


class CollectionsAbcTests(BaseTestCase):
Expand Down
23 changes: 18 additions & 5 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _subs_tvars(tp, tvars, subs):
"""Substitute type variables 'tvars' with substitutions 'subs'.
These two must have the same length.
"""
if not isinstance(tp, _GenericAlias):
if not isinstance(tp, (_GenericAlias, GenericAlias)):
return tp
new_args = list(tp.__args__)
for a, arg in enumerate(tp.__args__):
Expand All @@ -203,7 +203,10 @@ def _subs_tvars(tp, tvars, subs):
new_args[a] = _subs_tvars(arg, tvars, subs)
if tp.__origin__ is Union:
return Union[tuple(new_args)]
return tp.copy_with(tuple(new_args))
if isinstance(tp, GenericAlias):
return GenericAlias(tp.__origin__, tuple(new_args))
else:
return tp.copy_with(tuple(new_args))


def _check_generic(cls, parameters):
Expand Down Expand Up @@ -278,6 +281,11 @@ def _eval_type(t, globalns, localns):
res = t.copy_with(ev_args)
res._special = t._special
return res
if isinstance(t, GenericAlias):
ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
if ev_args == t.__args__:
return t
return GenericAlias(t.__origin__, ev_args)
return t


Expand Down Expand Up @@ -1368,6 +1376,11 @@ def _strip_annotations(t):
res = t.copy_with(stripped_args)
res._special = t._special
return res
if isinstance(t, GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
return GenericAlias(t.__origin__, stripped_args)
return t


Expand All @@ -1387,7 +1400,7 @@ def get_origin(tp):
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, _GenericAlias):
if isinstance(tp, (_GenericAlias, GenericAlias)):
return tp.__origin__
if tp is Generic:
return Generic
Expand All @@ -1407,9 +1420,9 @@ def get_args(tp):
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, _GenericAlias):
if isinstance(tp, (_GenericAlias, GenericAlias)):
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
return ()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Functions :func:`typing.get_origin`, :func:`typing.get_args` and
:func:`typing.get_type_hints` support now generic aliases like
``list[int]``.