Skip to content

Commit 1f98818

Browse files
Add __orig_bases__ to all TypedDict and NamedTuple (#150)
Co-authored-by: AlexWaygood <[email protected]>
1 parent 0b8de38 commit 1f98818

File tree

3 files changed

+101
-24
lines changed

3 files changed

+101
-24
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@
4949
Patch by Alex Waygood.
5050
- Speedup `isinstance(3, typing_extensions.SupportsIndex)` by >10x on Python
5151
<3.12. Patch by Alex Waygood.
52+
- Add `__orig_bases__` to non-generic TypedDicts, call-based TypedDicts, and
53+
call-based NamedTuples. Other TypedDicts and NamedTuples already had the attribute.
54+
Patch by Adrian Garcia Badaracco.
55+
- Constructing a call-based `TypedDict` using keyword arguments for the fields
56+
now causes a `DeprecationWarning` to be emitted. This matches the behaviour
57+
of `typing.TypedDict` on 3.11 and 3.12.
5258

5359
# Release 4.5.0 (February 14, 2023)
5460

src/test_typing_extensions.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,6 @@ def assertNotIsSubclass(self, cls, class_or_tuple, msg=None):
172172
message += f' : {msg}'
173173
raise self.failureException(message)
174174

175-
@contextlib.contextmanager
176-
def assertWarnsIf(self, condition: bool, expected_warning: Type[Warning]):
177-
with contextlib.ExitStack() as stack:
178-
if condition:
179-
stack.enter_context(self.assertWarns(expected_warning))
180-
yield
181-
182175

183176
class Employee:
184177
pass
@@ -2467,7 +2460,7 @@ def test_basics_iterable_syntax(self):
24672460
self.assertEqual(Emp.__total__, True)
24682461

24692462
def test_basics_keywords_syntax(self):
2470-
with self.assertWarnsIf(sys.version_info >= (3, 11), DeprecationWarning):
2463+
with self.assertWarns(DeprecationWarning):
24712464
Emp = TypedDict('Emp', name=str, id=int)
24722465
self.assertIsSubclass(Emp, dict)
24732466
self.assertIsSubclass(Emp, typing.MutableMapping)
@@ -2483,7 +2476,7 @@ def test_basics_keywords_syntax(self):
24832476
self.assertEqual(Emp.__total__, True)
24842477

24852478
def test_typeddict_special_keyword_names(self):
2486-
with self.assertWarnsIf(sys.version_info >= (3, 11), DeprecationWarning):
2479+
with self.assertWarns(DeprecationWarning):
24872480
TD = TypedDict("TD", cls=type, self=object, typename=str, _typename=int,
24882481
fields=list, _fields=dict)
24892482
self.assertEqual(TD.__name__, 'TD')
@@ -2519,7 +2512,7 @@ def test_typeddict_create_errors(self):
25192512

25202513
def test_typeddict_errors(self):
25212514
Emp = TypedDict('Emp', {'name': str, 'id': int})
2522-
if hasattr(typing, "Required"):
2515+
if sys.version_info >= (3, 12):
25232516
self.assertEqual(TypedDict.__module__, 'typing')
25242517
else:
25252518
self.assertEqual(TypedDict.__module__, 'typing_extensions')
@@ -2532,7 +2525,7 @@ def test_typeddict_errors(self):
25322525
issubclass(dict, Emp)
25332526

25342527
if not TYPING_3_11_0:
2535-
with self.assertRaises(TypeError):
2528+
with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning):
25362529
TypedDict('Hi', x=1)
25372530
with self.assertRaises(TypeError):
25382531
TypedDict('Hi', [('x', int), ('y', 1)])
@@ -3036,6 +3029,49 @@ def test_get_type_hints_typeddict(self):
30363029
'year': NotRequired[Annotated[int, 2000]],
30373030
}
30383031

3032+
def test_orig_bases(self):
3033+
T = TypeVar('T')
3034+
3035+
class Parent(TypedDict):
3036+
pass
3037+
3038+
class Child(Parent):
3039+
pass
3040+
3041+
class OtherChild(Parent):
3042+
pass
3043+
3044+
class MixedChild(Child, OtherChild, Parent):
3045+
pass
3046+
3047+
class GenericParent(TypedDict, Generic[T]):
3048+
pass
3049+
3050+
class GenericChild(GenericParent[int]):
3051+
pass
3052+
3053+
class OtherGenericChild(GenericParent[str]):
3054+
pass
3055+
3056+
class MixedGenericChild(GenericChild, OtherGenericChild, GenericParent[float]):
3057+
pass
3058+
3059+
class MultipleGenericBases(GenericParent[int], GenericParent[float]):
3060+
pass
3061+
3062+
CallTypedDict = TypedDict('CallTypedDict', {})
3063+
3064+
self.assertEqual(Parent.__orig_bases__, (TypedDict,))
3065+
self.assertEqual(Child.__orig_bases__, (Parent,))
3066+
self.assertEqual(OtherChild.__orig_bases__, (Parent,))
3067+
self.assertEqual(MixedChild.__orig_bases__, (Child, OtherChild, Parent,))
3068+
self.assertEqual(GenericParent.__orig_bases__, (TypedDict, Generic[T]))
3069+
self.assertEqual(GenericChild.__orig_bases__, (GenericParent[int],))
3070+
self.assertEqual(OtherGenericChild.__orig_bases__, (GenericParent[str],))
3071+
self.assertEqual(MixedGenericChild.__orig_bases__, (GenericChild, OtherGenericChild, GenericParent[float]))
3072+
self.assertEqual(MultipleGenericBases.__orig_bases__, (GenericParent[int], GenericParent[float]))
3073+
self.assertEqual(CallTypedDict.__orig_bases__, (TypedDict,))
3074+
30393075

30403076
class TypeAliasTests(BaseTestCase):
30413077
def test_canonical_usage_with_variable_annotation(self):
@@ -3802,22 +3838,23 @@ def test_typing_extensions_defers_when_possible(self):
38023838
'overload',
38033839
'ParamSpec',
38043840
'Text',
3805-
'TypedDict',
38063841
'TypeVar',
38073842
'TypeVarTuple',
38083843
'TYPE_CHECKING',
38093844
'Final',
38103845
'get_type_hints',
3811-
'is_typeddict',
38123846
}
38133847
if sys.version_info < (3, 10):
38143848
exclude |= {'get_args', 'get_origin'}
38153849
if sys.version_info < (3, 10, 1):
38163850
exclude |= {"Literal"}
38173851
if sys.version_info < (3, 11):
3818-
exclude |= {'final', 'NamedTuple', 'Any'}
3852+
exclude |= {'final', 'Any'}
38193853
if sys.version_info < (3, 12):
3820-
exclude |= {'Protocol', 'runtime_checkable', 'SupportsIndex'}
3854+
exclude |= {
3855+
'Protocol', 'runtime_checkable', 'SupportsIndex', 'TypedDict',
3856+
'is_typeddict', 'NamedTuple',
3857+
}
38213858
for item in typing_extensions.__all__:
38223859
if item not in exclude and hasattr(typing, item):
38233860
self.assertIs(
@@ -3863,7 +3900,6 @@ def __add__(self, other):
38633900
return 0
38643901

38653902

3866-
@skipIf(TYPING_3_11_0, "These invariants should all be tested upstream on 3.11+")
38673903
class NamedTupleTests(BaseTestCase):
38683904
class NestedEmployee(NamedTuple):
38693905
name: str
@@ -4003,7 +4039,9 @@ class Y(Generic[T], NamedTuple):
40034039
self.assertIs(type(a), G)
40044040
self.assertEqual(a.x, 3)
40054041

4006-
with self.assertRaisesRegex(TypeError, 'Too many parameters'):
4042+
things = "arguments" if sys.version_info >= (3, 11) else "parameters"
4043+
4044+
with self.assertRaisesRegex(TypeError, f'Too many {things}'):
40074045
G[int, str]
40084046

40094047
@skipUnless(TYPING_3_9_0, "tuple.__class_getitem__ was added in 3.9")
@@ -4134,6 +4172,22 @@ def test_same_as_typing_NamedTuple_38_minus(self):
41344172
self.NestedEmployee._field_types
41354173
)
41364174

4175+
def test_orig_bases(self):
4176+
T = TypeVar('T')
4177+
4178+
class SimpleNamedTuple(NamedTuple):
4179+
pass
4180+
4181+
class GenericNamedTuple(NamedTuple, Generic[T]):
4182+
pass
4183+
4184+
self.assertEqual(SimpleNamedTuple.__orig_bases__, (NamedTuple,))
4185+
self.assertEqual(GenericNamedTuple.__orig_bases__, (NamedTuple, Generic[T]))
4186+
4187+
CallNamedTuple = NamedTuple('CallNamedTuple', [])
4188+
4189+
self.assertEqual(CallNamedTuple.__orig_bases__, (NamedTuple,))
4190+
41374191

41384192
class TypeVarLikeDefaultsTests(BaseTestCase):
41394193
def test_typevar(self):

src/typing_extensions.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -749,14 +749,16 @@ def __index__(self) -> int:
749749
pass
750750

751751

752-
if hasattr(typing, "Required"):
752+
if sys.version_info >= (3, 12):
753753
# The standard library TypedDict in Python 3.8 does not store runtime information
754754
# about which (if any) keys are optional. See https://bugs.python.org/issue38834
755755
# The standard library TypedDict in Python 3.9.0/1 does not honour the "total"
756756
# keyword with old-style TypedDict(). See https://bugs.python.org/issue42059
757757
# The standard library TypedDict below Python 3.11 does not store runtime
758758
# information about optional and required keys when using Required or NotRequired.
759759
# Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11.
760+
# Aaaand on 3.12 we add __orig_bases__ to TypedDict
761+
# to enable better runtime introspection.
760762
TypedDict = typing.TypedDict
761763
_TypedDictMeta = typing._TypedDictMeta
762764
is_typeddict = typing.is_typeddict
@@ -786,7 +788,6 @@ def _typeddict_new(*args, total=True, **kwargs):
786788
typename, args = args[0], args[1:] # allow the "_typename" keyword be passed
787789
elif '_typename' in kwargs:
788790
typename = kwargs.pop('_typename')
789-
import warnings
790791
warnings.warn("Passing '_typename' as keyword argument is deprecated",
791792
DeprecationWarning, stacklevel=2)
792793
else:
@@ -801,7 +802,6 @@ def _typeddict_new(*args, total=True, **kwargs):
801802
'were given')
802803
elif '_fields' in kwargs and len(kwargs) == 1:
803804
fields = kwargs.pop('_fields')
804-
import warnings
805805
warnings.warn("Passing '_fields' as keyword argument is deprecated",
806806
DeprecationWarning, stacklevel=2)
807807
else:
@@ -813,6 +813,15 @@ def _typeddict_new(*args, total=True, **kwargs):
813813
raise TypeError("TypedDict takes either a dict or keyword arguments,"
814814
" but not both")
815815

816+
if kwargs:
817+
warnings.warn(
818+
"The kwargs-based syntax for TypedDict definitions is deprecated, "
819+
"may be removed in a future version, and may not be "
820+
"understood by third-party type checkers.",
821+
DeprecationWarning,
822+
stacklevel=2,
823+
)
824+
816825
ns = {'__annotations__': dict(fields)}
817826
module = _caller()
818827
if module is not None:
@@ -844,9 +853,14 @@ def __new__(cls, name, bases, ns, total=True):
844853
# Instead, monkey-patch __bases__ onto the class after it's been created.
845854
tp_dict = super().__new__(cls, name, (dict,), ns)
846855

847-
if any(issubclass(base, typing.Generic) for base in bases):
856+
is_generic = any(issubclass(base, typing.Generic) for base in bases)
857+
858+
if is_generic:
848859
tp_dict.__bases__ = (typing.Generic, dict)
849860
_maybe_adjust_parameters(tp_dict)
861+
else:
862+
# generic TypedDicts get __orig_bases__ from Generic
863+
tp_dict.__orig_bases__ = bases or (TypedDict,)
850864

851865
annotations = {}
852866
own_annotations = ns.get('__annotations__', {})
@@ -2313,10 +2327,11 @@ def wrapper(*args, **kwargs):
23132327
typing._check_generic = _check_generic
23142328

23152329

2316-
# Backport typing.NamedTuple as it exists in Python 3.11.
2330+
# Backport typing.NamedTuple as it exists in Python 3.12.
23172331
# In 3.11, the ability to define generic `NamedTuple`s was supported.
23182332
# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8.
2319-
if sys.version_info >= (3, 11):
2333+
# On 3.12, we added __orig_bases__ to call-based NamedTuples
2334+
if sys.version_info >= (3, 12):
23202335
NamedTuple = typing.NamedTuple
23212336
else:
23222337
def _make_nmtuple(name, types, module, defaults=()):
@@ -2378,7 +2393,9 @@ def NamedTuple(__typename, __fields=None, **kwargs):
23782393
elif kwargs:
23792394
raise TypeError("Either list of fields or keywords"
23802395
" can be provided to NamedTuple, not both")
2381-
return _make_nmtuple(__typename, __fields, module=_caller())
2396+
nt = _make_nmtuple(__typename, __fields, module=_caller())
2397+
nt.__orig_bases__ = (NamedTuple,)
2398+
return nt
23822399

23832400
NamedTuple.__doc__ = typing.NamedTuple.__doc__
23842401
_NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {})

0 commit comments

Comments
 (0)