Skip to content

Commit d34d285

Browse files
authored
Add support for additional TypedDict methods (#6011)
This adds support for these methods through additional plugin hooks: * `pop` * `setdefault` * `update` (positional argument only) * `__delitem__` These methods also work and don't need plugin support: * `copy` * `has_key` (Python 2 only) * `viewitems` (Python 2 only) * `viewkeys` (Python 2 only) * `viewvalues` (Python 2 only) The base signatures for all of these methods are defined in `mypy_extensions._TypedDict`, which is a stub-only class only used internally by mypy. It becomes the new fallback type of all TypedDicts. Fixes #3843. Fixes #3550. There's some possible follow-up work that I'm leaving to other PRs, such as optimizing hook lookup through dictionaries in the default plugin, documenting the supported methods, and `update` with keyword arguments (#6019).
1 parent 779fa04 commit d34d285

14 files changed

+329
-51
lines changed

mypy/checker.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -3830,11 +3830,12 @@ def builtin_item_type(tp: Type) -> Optional[Type]:
38303830
elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) for it in tp.items):
38313831
return UnionType.make_simplified_union(tp.items) # this type is not externally visible
38323832
elif isinstance(tp, TypedDictType):
3833-
# TypedDict always has non-optional string keys.
3834-
if tp.fallback.type.fullname() == 'typing.Mapping':
3835-
return tp.fallback.args[0]
3836-
elif tp.fallback.type.bases[0].type.fullname() == 'typing.Mapping':
3837-
return tp.fallback.type.bases[0].args[0]
3833+
# TypedDict always has non-optional string keys. Find the key type from the Mapping
3834+
# base class.
3835+
for base in tp.fallback.type.mro:
3836+
if base.fullname() == 'typing.Mapping':
3837+
return map_instance_to_supertype(tp.fallback, base).args[0]
3838+
assert False, 'No Mapping base class found for TypedDict fallback'
38383839
return None
38393840

38403841

mypy/messages.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -1157,11 +1157,11 @@ def unexpected_typeddict_keys(
11571157
format_key_list(extra, short=True), self.format(typ)),
11581158
context)
11591159
return
1160-
if not expected_keys:
1161-
expected = '(no keys)'
1162-
else:
1163-
expected = format_key_list(expected_keys)
11641160
found = format_key_list(actual_keys, short=True)
1161+
if not expected_keys:
1162+
self.fail('Unexpected TypedDict {}'.format(found), context)
1163+
return
1164+
expected = format_key_list(expected_keys)
11651165
if actual_keys and actual_set < expected_set:
11661166
found = 'only {}'.format(found)
11671167
self.fail('Expected {} but found {}'.format(expected, found), context)
@@ -1185,6 +1185,18 @@ def typeddict_key_not_found(
11851185
else:
11861186
self.fail("TypedDict {} has no key '{}'".format(self.format(typ), item_name), context)
11871187

1188+
def typeddict_key_cannot_be_deleted(
1189+
self,
1190+
typ: TypedDictType,
1191+
item_name: str,
1192+
context: Context) -> None:
1193+
if typ.is_anonymous():
1194+
self.fail("TypedDict key '{}' cannot be deleted".format(item_name),
1195+
context)
1196+
else:
1197+
self.fail("Key '{}' of TypedDict {} cannot be deleted".format(
1198+
item_name, self.format(typ)), context)
1199+
11881200
def type_arguments_not_allowed(self, context: Context) -> None:
11891201
self.fail('Parameterized generics cannot be used with class or instance checks', context)
11901202

mypy/plugin.py

+149
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, TypeVarType,
1717
AnyType, TypeList, UnboundType, TypeOfAny, TypeType,
1818
)
19+
from mypy import messages
1920
from mypy.messages import MessageBuilder
2021
from mypy.options import Options
2122
import mypy.interpreted_plugin
@@ -61,6 +62,10 @@ class CheckerPluginInterface:
6162
msg = None # type: MessageBuilder
6263
options = None # type: Options
6364

65+
@abstractmethod
66+
def fail(self, msg: str, ctx: Context) -> None:
67+
raise NotImplementedError
68+
6469
@abstractmethod
6570
def named_generic_type(self, name: str, args: List[Type]) -> Instance:
6671
raise NotImplementedError
@@ -400,6 +405,14 @@ def get_method_signature_hook(self, fullname: str
400405

401406
if fullname == 'typing.Mapping.get':
402407
return typed_dict_get_signature_callback
408+
elif fullname == 'mypy_extensions._TypedDict.setdefault':
409+
return typed_dict_setdefault_signature_callback
410+
elif fullname == 'mypy_extensions._TypedDict.pop':
411+
return typed_dict_pop_signature_callback
412+
elif fullname == 'mypy_extensions._TypedDict.update':
413+
return typed_dict_update_signature_callback
414+
elif fullname == 'mypy_extensions._TypedDict.__delitem__':
415+
return typed_dict_delitem_signature_callback
403416
elif fullname == 'ctypes.Array.__setitem__':
404417
return ctypes.array_setitem_callback
405418
return None
@@ -412,6 +425,12 @@ def get_method_hook(self, fullname: str
412425
return typed_dict_get_callback
413426
elif fullname == 'builtins.int.__pow__':
414427
return int_pow_callback
428+
elif fullname == 'mypy_extensions._TypedDict.setdefault':
429+
return typed_dict_setdefault_callback
430+
elif fullname == 'mypy_extensions._TypedDict.pop':
431+
return typed_dict_pop_callback
432+
elif fullname == 'mypy_extensions._TypedDict.__delitem__':
433+
return typed_dict_delitem_callback
415434
elif fullname == 'ctypes.Array.__getitem__':
416435
return ctypes.array_getitem_callback
417436
elif fullname == 'ctypes.Array.__iter__':
@@ -544,6 +563,136 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
544563
return ctx.default_return_type
545564

546565

566+
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
567+
"""Try to infer a better signature type for TypedDict.pop.
568+
569+
This is used to get better type context for the second argument that
570+
depends on a TypedDict value type.
571+
"""
572+
signature = ctx.default_signature
573+
str_type = ctx.api.named_generic_type('builtins.str', [])
574+
if (isinstance(ctx.type, TypedDictType)
575+
and len(ctx.args) == 2
576+
and len(ctx.args[0]) == 1
577+
and isinstance(ctx.args[0][0], StrExpr)
578+
and len(signature.arg_types) == 2
579+
and len(signature.variables) == 1
580+
and len(ctx.args[1]) == 1):
581+
key = ctx.args[0][0].value
582+
value_type = ctx.type.items.get(key)
583+
if value_type:
584+
# Tweak the signature to include the value type as context. It's
585+
# only needed for type inference since there's a union with a type
586+
# variable that accepts everything.
587+
tv = TypeVarType(signature.variables[0])
588+
typ = UnionType.make_simplified_union([value_type, tv])
589+
return signature.copy_modified(
590+
arg_types=[str_type, typ],
591+
ret_type=typ)
592+
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
593+
594+
595+
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
596+
"""Type check and infer a precise return type for TypedDict.pop."""
597+
if (isinstance(ctx.type, TypedDictType)
598+
and len(ctx.arg_types) >= 1
599+
and len(ctx.arg_types[0]) == 1):
600+
if isinstance(ctx.args[0][0], StrExpr):
601+
key = ctx.args[0][0].value
602+
if key in ctx.type.required_keys:
603+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
604+
value_type = ctx.type.items.get(key)
605+
if value_type:
606+
if len(ctx.args[1]) == 0:
607+
return value_type
608+
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
609+
and len(ctx.args[1]) == 1):
610+
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
611+
else:
612+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
613+
return AnyType(TypeOfAny.from_error)
614+
else:
615+
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
616+
return AnyType(TypeOfAny.from_error)
617+
return ctx.default_return_type
618+
619+
620+
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
621+
"""Try to infer a better signature type for TypedDict.setdefault.
622+
623+
This is used to get better type context for the second argument that
624+
depends on a TypedDict value type.
625+
"""
626+
signature = ctx.default_signature
627+
str_type = ctx.api.named_generic_type('builtins.str', [])
628+
if (isinstance(ctx.type, TypedDictType)
629+
and len(ctx.args) == 2
630+
and len(ctx.args[0]) == 1
631+
and isinstance(ctx.args[0][0], StrExpr)
632+
and len(signature.arg_types) == 2
633+
and len(ctx.args[1]) == 1):
634+
key = ctx.args[0][0].value
635+
value_type = ctx.type.items.get(key)
636+
if value_type:
637+
return signature.copy_modified(arg_types=[str_type, value_type])
638+
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
639+
640+
641+
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
642+
"""Type check TypedDict.setdefault and infer a precise return type."""
643+
if (isinstance(ctx.type, TypedDictType)
644+
and len(ctx.arg_types) == 2
645+
and len(ctx.arg_types[0]) == 1):
646+
if isinstance(ctx.args[0][0], StrExpr):
647+
key = ctx.args[0][0].value
648+
value_type = ctx.type.items.get(key)
649+
if value_type:
650+
return value_type
651+
else:
652+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
653+
return AnyType(TypeOfAny.from_error)
654+
else:
655+
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
656+
return AnyType(TypeOfAny.from_error)
657+
return ctx.default_return_type
658+
659+
660+
def typed_dict_delitem_signature_callback(ctx: MethodSigContext) -> CallableType:
661+
# Replace NoReturn as the argument type.
662+
str_type = ctx.api.named_generic_type('builtins.str', [])
663+
return ctx.default_signature.copy_modified(arg_types=[str_type])
664+
665+
666+
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
667+
"""Type check TypedDict.__delitem__."""
668+
if (isinstance(ctx.type, TypedDictType)
669+
and len(ctx.arg_types) == 1
670+
and len(ctx.arg_types[0]) == 1):
671+
if isinstance(ctx.args[0][0], StrExpr):
672+
key = ctx.args[0][0].value
673+
if key in ctx.type.required_keys:
674+
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
675+
elif key not in ctx.type.items:
676+
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
677+
else:
678+
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
679+
return AnyType(TypeOfAny.from_error)
680+
return ctx.default_return_type
681+
682+
683+
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
684+
"""Try to infer a better signature type for TypedDict.update."""
685+
signature = ctx.default_signature
686+
if (isinstance(ctx.type, TypedDictType)
687+
and len(signature.arg_types) == 1):
688+
arg_type = signature.arg_types[0]
689+
assert isinstance(arg_type, TypedDictType)
690+
arg_type = arg_type.as_anonymous()
691+
arg_type = arg_type.copy_modified(required_keys=set())
692+
return signature.copy_modified(arg_types=[arg_type])
693+
return signature
694+
695+
547696
def int_pow_callback(ctx: MethodContext) -> Type:
548697
"""Infer a more precise return type for int.__pow__."""
549698
if (len(ctx.arg_types) == 1

mypy/semanal_shared.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,13 @@ def create_indirect_imported_name(file_node: MypyFile,
144144
def set_callable_name(sig: Type, fdef: FuncDef) -> Type:
145145
if isinstance(sig, FunctionLike):
146146
if fdef.info:
147+
if fdef.info.fullname() == 'mypy_extensions._TypedDict':
148+
# Avoid exposing the internal _TypedDict name.
149+
class_name = 'TypedDict'
150+
else:
151+
class_name = fdef.info.name()
147152
return sig.with_name(
148-
'{} of {}'.format(fdef.name(), fdef.info.name()))
153+
'{} of {}'.format(fdef.name(), class_name))
149154
else:
150155
return sig.with_name(fdef.name())
151156
else:

mypy/semanal_typeddict.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,8 @@ def fail_typeddict_arg(self, message: str,
276276
def build_typeddict_typeinfo(self, name: str, items: List[str],
277277
types: List[Type],
278278
required_keys: Set[str]) -> TypeInfo:
279-
fallback = (self.api.named_type_or_none('typing.Mapping',
280-
[self.api.named_type('__builtins__.str'),
281-
self.api.named_type('__builtins__.object')])
282-
or self.api.named_type('__builtins__.object'))
279+
fallback = self.api.named_type_or_none('mypy_extensions._TypedDict', [])
280+
assert fallback is not None
283281
info = self.api.basic_new_typeinfo(name, fallback)
284282
info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys,
285283
fallback)

mypy/server/deps.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -790,8 +790,8 @@ def add_dependency(self, trigger: str, target: Optional[str] = None) -> None:
790790
791791
If the target is not given explicitly, use the current target.
792792
"""
793-
if trigger.startswith(('<builtins.', '<typing.')):
794-
# Don't track dependencies to certain builtins to keep the size of
793+
if trigger.startswith(('<builtins.', '<typing.', '<mypy_extensions.')):
794+
# Don't track dependencies to certain library modules to keep the size of
795795
# the dependencies manageable. These dependencies should only
796796
# change on mypy version updates, which will require a full rebuild
797797
# anyway.

mypy/test/testsemanal.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
8181
'mypy_extensions.pyi',
8282
'typing_extensions.pyi',
8383
'abc.pyi',
84-
'collections.pyi'))
84+
'collections.pyi',
85+
'sys.pyi'))
8586
and not os.path.basename(f.path).startswith('_')
8687
and not os.path.splitext(
8788
os.path.basename(f.path))[0].endswith('_')):

mypy/types.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -1171,12 +1171,23 @@ def slice(self, begin: Optional[int], stride: Optional[int],
11711171

11721172

11731173
class TypedDictType(Type):
1174-
"""The type of a TypedDict instance. TypedDict(K1=VT1, ..., Kn=VTn)
1174+
"""Type of TypedDict object {'k1': v1, ..., 'kn': vn}.
11751175
1176-
A TypedDictType can be either named or anonymous.
1177-
If it is anonymous then its fallback will be an Instance of Mapping[str, V].
1178-
If it is named then its fallback will be an Instance of the named type (ex: "Point")
1179-
whose TypeInfo has a typeddict_type that is anonymous.
1176+
A TypedDict object is a dictionary with specific string (literal) keys. Each
1177+
key has a value with a distinct type that depends on the key. TypedDict objects
1178+
are normal dict objects at runtime.
1179+
1180+
A TypedDictType can be either named or anonymous. If it's anonymous, its
1181+
fallback will mypy_extensions._TypedDict (Instance). _TypedDict is a subclass
1182+
of Mapping[str, object] and defines all non-mapping dict methods that TypedDict
1183+
supports. Some dict methods are unsafe and not supported. _TypedDict isn't defined
1184+
at runtime.
1185+
1186+
If a TypedDict is named, its fallback will be an Instance of the named type
1187+
(ex: "Point") whose TypeInfo has a typeddict_type that is anonymous. This
1188+
is similar to how named tuples work.
1189+
1190+
TODO: The fallback structure is perhaps overly complicated.
11801191
"""
11811192

11821193
items = None # type: OrderedDict[str, Type] # item_name -> item_type
@@ -1226,7 +1237,7 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType':
12261237
Instance.deserialize(data['fallback']))
12271238

12281239
def is_anonymous(self) -> bool:
1229-
return self.fallback.type.fullname() == 'typing.Mapping'
1240+
return self.fallback.type.fullname() == 'mypy_extensions._TypedDict'
12301241

12311242
def as_anonymous(self) -> 'TypedDictType':
12321243
if self.is_anonymous():
@@ -1249,10 +1260,7 @@ def copy_modified(self, *, fallback: Optional[Instance] = None,
12491260

12501261
def create_anonymous_fallback(self, *, value_type: Type) -> Instance:
12511262
anonymous = self.as_anonymous()
1252-
return anonymous.fallback.copy_modified(args=[ # i.e. Mapping
1253-
anonymous.fallback.args[0], # i.e. str
1254-
value_type
1255-
])
1263+
return anonymous.fallback
12561264

12571265
def names_are_wider_than(self, other: 'TypedDictType') -> bool:
12581266
return len(other.items.keys() - self.items.keys()) == 0
@@ -1821,13 +1829,10 @@ def item_str(name: str, typ: str) -> str:
18211829
s = '{' + ', '.join(item_str(name, typ.accept(self))
18221830
for name, typ in t.items.items()) + '}'
18231831
prefix = ''
1824-
suffix = ''
18251832
if t.fallback and t.fallback.type:
1826-
if t.fallback.type.fullname() != 'typing.Mapping':
1833+
if t.fallback.type.fullname() != 'mypy_extensions._TypedDict':
18271834
prefix = repr(t.fallback.type.fullname()) + ', '
1828-
else:
1829-
suffix = ', fallback={}'.format(t.fallback.accept(self))
1830-
return 'TypedDict({}{}{})'.format(prefix, s, suffix)
1835+
return 'TypedDict({}{})'.format(prefix, s)
18311836

18321837
def visit_raw_literal_type(self, t: RawLiteralType) -> str:
18331838
return repr(t.value)

0 commit comments

Comments
 (0)