Skip to content

Commit 68b352a

Browse files
bpo-40396: Support GenericAlias in the typing functions. (GH-19718)
1 parent cfaf4c0 commit 68b352a

File tree

3 files changed

+62
-6
lines changed

3 files changed

+62
-6
lines changed

Lib/test/test_typing.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import NamedTuple, TypedDict
2323
from typing import IO, TextIO, BinaryIO
2424
from typing import Pattern, Match
25-
from typing import Annotated
25+
from typing import Annotated, ForwardRef
2626
import abc
2727
import typing
2828
import weakref
@@ -1756,11 +1756,17 @@ def test_extended_generic_rules_repr(self):
17561756

17571757
def test_generic_forward_ref(self):
17581758
def foobar(x: List[List['CC']]): ...
1759+
def foobar2(x: list[list[ForwardRef('CC')]]): ...
17591760
class CC: ...
17601761
self.assertEqual(
17611762
get_type_hints(foobar, globals(), locals()),
17621763
{'x': List[List[CC]]}
17631764
)
1765+
self.assertEqual(
1766+
get_type_hints(foobar2, globals(), locals()),
1767+
{'x': list[list[CC]]}
1768+
)
1769+
17641770
T = TypeVar('T')
17651771
AT = Tuple[T, ...]
17661772
def barfoo(x: AT): ...
@@ -2446,6 +2452,12 @@ def foo(a: Tuple['T']):
24462452
self.assertEqual(get_type_hints(foo, globals(), locals()),
24472453
{'a': Tuple[T]})
24482454

2455+
def foo(a: tuple[ForwardRef('T')]):
2456+
pass
2457+
2458+
self.assertEqual(get_type_hints(foo, globals(), locals()),
2459+
{'a': tuple[T]})
2460+
24492461
def test_forward_recursion_actually(self):
24502462
def namespace1():
24512463
a = typing.ForwardRef('A')
@@ -2909,19 +2921,41 @@ def foobar(x: List['X']): ...
29092921
get_type_hints(foobar, globals(), locals(), include_extras=True),
29102922
{'x': List[Annotated[int, (1, 10)]]}
29112923
)
2924+
2925+
def foobar(x: list[ForwardRef('X')]): ...
2926+
X = Annotated[int, (1, 10)]
2927+
self.assertEqual(
2928+
get_type_hints(foobar, globals(), locals()),
2929+
{'x': list[int]}
2930+
)
2931+
self.assertEqual(
2932+
get_type_hints(foobar, globals(), locals(), include_extras=True),
2933+
{'x': list[Annotated[int, (1, 10)]]}
2934+
)
2935+
29122936
BA = Tuple[Annotated[T, (1, 0)], ...]
29132937
def barfoo(x: BA): ...
29142938
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...])
29152939
self.assertIs(
29162940
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
29172941
BA
29182942
)
2943+
2944+
BA = tuple[Annotated[T, (1, 0)], ...]
2945+
def barfoo(x: BA): ...
2946+
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], tuple[T, ...])
2947+
self.assertIs(
2948+
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
2949+
BA
2950+
)
2951+
29192952
def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]],
29202953
y: typing.Union[int, Annotated[T, "mutable"]]): ...
29212954
self.assertEqual(
29222955
get_type_hints(barfoo2, globals(), locals()),
29232956
{'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]}
29242957
)
2958+
29252959
BA2 = typing.Callable[..., List[T]]
29262960
def barfoo3(x: BA2): ...
29272961
self.assertIs(
@@ -2972,6 +3006,9 @@ class C(Generic[T]): pass
29723006
self.assertIs(get_origin(Generic[T]), Generic)
29733007
self.assertIs(get_origin(List[Tuple[T, T]][int]), list)
29743008
self.assertIs(get_origin(Annotated[T, 'thing']), Annotated)
3009+
self.assertIs(get_origin(List), list)
3010+
self.assertIs(get_origin(list[int]), list)
3011+
self.assertIs(get_origin(list), None)
29753012

29763013
def test_get_args(self):
29773014
T = TypeVar('T')
@@ -2993,6 +3030,9 @@ class C(Generic[T]): pass
29933030
self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
29943031
self.assertEqual(get_args(Tuple[()]), ((),))
29953032
self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three']))
3033+
self.assertEqual(get_args(List), (typing.T,))
3034+
self.assertEqual(get_args(list[int]), (int,))
3035+
self.assertEqual(get_args(list), ())
29963036

29973037

29983038
class CollectionsAbcTests(BaseTestCase):

Lib/typing.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _subs_tvars(tp, tvars, subs):
191191
"""Substitute type variables 'tvars' with substitutions 'subs'.
192192
These two must have the same length.
193193
"""
194-
if not isinstance(tp, _GenericAlias):
194+
if not isinstance(tp, (_GenericAlias, GenericAlias)):
195195
return tp
196196
new_args = list(tp.__args__)
197197
for a, arg in enumerate(tp.__args__):
@@ -203,7 +203,10 @@ def _subs_tvars(tp, tvars, subs):
203203
new_args[a] = _subs_tvars(arg, tvars, subs)
204204
if tp.__origin__ is Union:
205205
return Union[tuple(new_args)]
206-
return tp.copy_with(tuple(new_args))
206+
if isinstance(tp, GenericAlias):
207+
return GenericAlias(tp.__origin__, tuple(new_args))
208+
else:
209+
return tp.copy_with(tuple(new_args))
207210

208211

209212
def _check_generic(cls, parameters):
@@ -278,6 +281,11 @@ def _eval_type(t, globalns, localns):
278281
res = t.copy_with(ev_args)
279282
res._special = t._special
280283
return res
284+
if isinstance(t, GenericAlias):
285+
ev_args = tuple(_eval_type(a, globalns, localns) for a in t.__args__)
286+
if ev_args == t.__args__:
287+
return t
288+
return GenericAlias(t.__origin__, ev_args)
281289
return t
282290

283291

@@ -1368,6 +1376,11 @@ def _strip_annotations(t):
13681376
res = t.copy_with(stripped_args)
13691377
res._special = t._special
13701378
return res
1379+
if isinstance(t, GenericAlias):
1380+
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
1381+
if stripped_args == t.__args__:
1382+
return t
1383+
return GenericAlias(t.__origin__, stripped_args)
13711384
return t
13721385

13731386

@@ -1387,7 +1400,7 @@ def get_origin(tp):
13871400
"""
13881401
if isinstance(tp, _AnnotatedAlias):
13891402
return Annotated
1390-
if isinstance(tp, _GenericAlias):
1403+
if isinstance(tp, (_GenericAlias, GenericAlias)):
13911404
return tp.__origin__
13921405
if tp is Generic:
13931406
return Generic
@@ -1407,9 +1420,9 @@ def get_args(tp):
14071420
"""
14081421
if isinstance(tp, _AnnotatedAlias):
14091422
return (tp.__origin__,) + tp.__metadata__
1410-
if isinstance(tp, _GenericAlias):
1423+
if isinstance(tp, (_GenericAlias, GenericAlias)):
14111424
res = tp.__args__
1412-
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
1425+
if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
14131426
res = (list(res[:-1]), res[-1])
14141427
return res
14151428
return ()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Functions :func:`typing.get_origin`, :func:`typing.get_args` and
2+
:func:`typing.get_type_hints` support now generic aliases like
3+
``list[int]``.

0 commit comments

Comments
 (0)