Skip to content

Commit a9d72e4

Browse files
ilevkivskyigvanrossum
authored andcommitted
Runtime implementation of TypedDict extension (#2552)
This was initially proposed in python/typing#322. It works on Python 2 and 3.
1 parent e9d28a0 commit a9d72e4

File tree

4 files changed

+199
-11
lines changed

4 files changed

+199
-11
lines changed

extensions/mypy_extensions.py

+76-8
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,84 @@
77

88
# NOTE: This module must support Python 2.7 in addition to Python 3.x
99

10+
import sys
11+
# _type_check is NOT a part of public typing API, it is used here only to mimic
12+
# the (convenient) behavior of types provided by typing module.
13+
from typing import _type_check # type: ignore
1014

11-
def TypedDict(typename, fields):
12-
"""TypedDict creates a dictionary type that expects all of its
15+
16+
def _check_fails(cls, other):
17+
try:
18+
if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']:
19+
# Typed dicts are only for static structural subtyping.
20+
raise TypeError('TypedDict does not support instance and class checks')
21+
except (AttributeError, ValueError):
22+
pass
23+
return False
24+
25+
def _dict_new(cls, *args, **kwargs):
26+
return dict(*args, **kwargs)
27+
28+
def _typeddict_new(cls, _typename, _fields=None, **kwargs):
29+
if _fields is None:
30+
_fields = kwargs
31+
elif kwargs:
32+
raise TypeError("TypedDict takes either a dict or keyword arguments,"
33+
" but not both")
34+
return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)})
35+
36+
class _TypedDictMeta(type):
37+
def __new__(cls, name, bases, ns):
38+
# Create new typed dict class object.
39+
# This method is called directly when TypedDict is subclassed,
40+
# or via _typeddict_new when TypedDict is instantiated. This way
41+
# TypedDict supports all three syntaxes described in its docstring.
42+
# Subclasses and instanes of TypedDict return actual dictionaries
43+
# via _dict_new.
44+
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
45+
tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns)
46+
try:
47+
# Setting correct module is necessary to make typed dict classes pickleable.
48+
tp_dict.__module__ = sys._getframe(2).f_globals.get('__name__', '__main__')
49+
except (AttributeError, ValueError):
50+
pass
51+
anns = ns.get('__annotations__', {})
52+
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
53+
anns = {n: _type_check(tp, msg) for n, tp in anns.items()}
54+
for base in bases:
55+
anns.update(base.__dict__.get('__annotations__', {}))
56+
tp_dict.__annotations__ = anns
57+
return tp_dict
58+
59+
__instancecheck__ = __subclasscheck__ = _check_fails
60+
61+
62+
TypedDict = _TypedDictMeta('TypedDict', (dict,), {})
63+
TypedDict.__module__ = __name__
64+
TypedDict.__doc__ = \
65+
"""A simple typed name space. At runtime it is equivalent to a plain dict.
66+
67+
TypedDict creates a dictionary type that expects all of its
1368
instances to have a certain set of keys, with each key
1469
associated with a value of a consistent type. This expectation
1570
is not checked at runtime but is only enforced by typecheckers.
16-
"""
17-
def new_dict(*args, **kwargs):
18-
return dict(*args, **kwargs)
71+
Usage::
72+
73+
Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
74+
a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
75+
b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
76+
assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
77+
78+
The type info could be accessed via Point2D.__annotations__. TypedDict
79+
supports two additional equivalent forms::
1980
20-
new_dict.__name__ = typename
21-
new_dict.__supertype__ = dict
22-
return new_dict
81+
Point2D = TypedDict('Point2D', x=int, y=int, label=str)
82+
83+
class Point2D(TypedDict):
84+
x: int
85+
y: int
86+
label: str
87+
88+
The latter syntax is only supported in Python 3.6+, while two other
89+
syntax forms work for Python 2.7 and 3.2+
90+
"""

mypy/semanal.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1873,8 +1873,9 @@ def process_typeddict_definition(self, s: AssignmentStmt) -> None:
18731873
return
18741874
# Yes, it's a valid TypedDict definition. Add it to the symbol table.
18751875
node = self.lookup(name, s)
1876-
node.kind = GDEF # TODO locally defined TypedDict
1877-
node.node = typed_dict
1876+
if node:
1877+
node.kind = GDEF # TODO locally defined TypedDict
1878+
node.node = typed_dict
18781879

18791880
def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]:
18801881
"""Check if a call defines a TypedDict.

mypy/test/testextensions.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import sys
2+
import pickle
3+
import typing
4+
try:
5+
import collections.abc as collections_abc
6+
except ImportError:
7+
import collections as collections_abc # type: ignore # PY32 and earlier
8+
from unittest import TestCase, main, skipUnless
9+
sys.path[0:0] = ['extensions']
10+
from mypy_extensions import TypedDict
11+
12+
13+
class BaseTestCase(TestCase):
14+
15+
def assertIsSubclass(self, cls, class_or_tuple, msg=None):
16+
if not issubclass(cls, class_or_tuple):
17+
message = '%r is not a subclass of %r' % (cls, class_or_tuple)
18+
if msg is not None:
19+
message += ' : %s' % msg
20+
raise self.failureException(message)
21+
22+
def assertNotIsSubclass(self, cls, class_or_tuple, msg=None):
23+
if issubclass(cls, class_or_tuple):
24+
message = '%r is a subclass of %r' % (cls, class_or_tuple)
25+
if msg is not None:
26+
message += ' : %s' % msg
27+
raise self.failureException(message)
28+
29+
30+
PY36 = sys.version_info[:2] >= (3, 6)
31+
32+
PY36_TESTS = """
33+
Label = TypedDict('Label', [('label', str)])
34+
35+
class Point2D(TypedDict):
36+
x: int
37+
y: int
38+
39+
class LabelPoint2D(Point2D, Label): ...
40+
"""
41+
42+
if PY36:
43+
exec(PY36_TESTS)
44+
45+
46+
class TypedDictTests(BaseTestCase):
47+
48+
def test_basics_iterable_syntax(self):
49+
Emp = TypedDict('Emp', {'name': str, 'id': int})
50+
self.assertIsSubclass(Emp, dict)
51+
self.assertIsSubclass(Emp, typing.MutableMapping)
52+
self.assertNotIsSubclass(Emp, collections_abc.Sequence)
53+
jim = Emp(name='Jim', id=1)
54+
self.assertIs(type(jim), dict)
55+
self.assertEqual(jim['name'], 'Jim')
56+
self.assertEqual(jim['id'], 1)
57+
self.assertEqual(Emp.__name__, 'Emp')
58+
self.assertEqual(Emp.__module__, 'mypy.test.testextensions')
59+
self.assertEqual(Emp.__bases__, (dict,))
60+
self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
61+
62+
def test_basics_keywords_syntax(self):
63+
Emp = TypedDict('Emp', name=str, id=int)
64+
self.assertIsSubclass(Emp, dict)
65+
self.assertIsSubclass(Emp, typing.MutableMapping)
66+
self.assertNotIsSubclass(Emp, collections_abc.Sequence)
67+
jim = Emp(name='Jim', id=1) # type: ignore # mypy doesn't support keyword syntax yet
68+
self.assertIs(type(jim), dict)
69+
self.assertEqual(jim['name'], 'Jim')
70+
self.assertEqual(jim['id'], 1)
71+
self.assertEqual(Emp.__name__, 'Emp')
72+
self.assertEqual(Emp.__module__, 'mypy.test.testextensions')
73+
self.assertEqual(Emp.__bases__, (dict,))
74+
self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
75+
76+
def test_typeddict_errors(self):
77+
Emp = TypedDict('Emp', {'name': str, 'id': int})
78+
self.assertEqual(TypedDict.__module__, 'mypy_extensions')
79+
jim = Emp(name='Jim', id=1)
80+
with self.assertRaises(TypeError):
81+
isinstance({}, Emp)
82+
with self.assertRaises(TypeError):
83+
isinstance(jim, Emp)
84+
with self.assertRaises(TypeError):
85+
issubclass(dict, Emp)
86+
with self.assertRaises(TypeError):
87+
TypedDict('Hi', x=1)
88+
with self.assertRaises(TypeError):
89+
TypedDict('Hi', [('x', int), ('y', 1)])
90+
with self.assertRaises(TypeError):
91+
TypedDict('Hi', [('x', int)], y=int)
92+
93+
@skipUnless(PY36, 'Python 3.6 required')
94+
def test_py36_class_syntax_usage(self):
95+
self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa
96+
self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa
97+
self.assertNotIsSubclass(LabelPoint2D, typing.Sequence) # noqa
98+
not_origin = Point2D(x=0, y=1) # noqa
99+
self.assertEqual(not_origin['x'], 0)
100+
self.assertEqual(not_origin['y'], 1)
101+
other = LabelPoint2D(x=0, y=1, label='hi') # noqa
102+
self.assertEqual(other['label'], 'hi')
103+
104+
def test_pickle(self):
105+
global EmpD # pickle wants to reference the class by name
106+
EmpD = TypedDict('EmpD', name=str, id=int)
107+
jane = EmpD({'name': 'jane', 'id': 37})
108+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
109+
z = pickle.dumps(jane, proto)
110+
jane2 = pickle.loads(z)
111+
self.assertEqual(jane2, jane)
112+
self.assertEqual(jane2, {'name': 'jane', 'id': 37})
113+
ZZ = pickle.dumps(EmpD, proto)
114+
EmpDnew = pickle.loads(ZZ)
115+
self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
116+
117+
118+
if __name__ == '__main__':
119+
main()

runtests.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def add_imports(driver: Driver) -> None:
207207

208208

209209
PYTEST_FILES = ['mypy/test/{}.py'.format(name) for name in [
210-
'testcheck',
210+
'testcheck', 'testextensions',
211211
]]
212212

213213

0 commit comments

Comments
 (0)