Skip to content

Add support for additional TypedDict methods #6011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3830,11 +3830,12 @@ def builtin_item_type(tp: Type) -> Optional[Type]:
elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) for it in tp.items):
return UnionType.make_simplified_union(tp.items) # this type is not externally visible
elif isinstance(tp, TypedDictType):
# TypedDict always has non-optional string keys.
if tp.fallback.type.fullname() == 'typing.Mapping':
return tp.fallback.args[0]
elif tp.fallback.type.bases[0].type.fullname() == 'typing.Mapping':
return tp.fallback.type.bases[0].args[0]
# TypedDict always has non-optional string keys. Find the key type from the Mapping
# base class.
for base in tp.fallback.type.mro:
if base.fullname() == 'typing.Mapping':
return map_instance_to_supertype(tp.fallback, base).args[0]
assert False, 'No Mapping base class found for TypedDict fallback'
return None


Expand Down
20 changes: 16 additions & 4 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,11 +1157,11 @@ def unexpected_typeddict_keys(
format_key_list(extra, short=True), self.format(typ)),
context)
return
if not expected_keys:
expected = '(no keys)'
else:
expected = format_key_list(expected_keys)
found = format_key_list(actual_keys, short=True)
if not expected_keys:
self.fail('Unexpected TypedDict {}'.format(found), context)
return
expected = format_key_list(expected_keys)
if actual_keys and actual_set < expected_set:
found = 'only {}'.format(found)
self.fail('Expected {} but found {}'.format(expected, found), context)
Expand All @@ -1185,6 +1185,18 @@ def typeddict_key_not_found(
else:
self.fail("TypedDict {} has no key '{}'".format(self.format(typ), item_name), context)

def typeddict_key_cannot_be_deleted(
self,
typ: TypedDictType,
item_name: str,
context: Context) -> None:
if typ.is_anonymous():
self.fail("TypedDict key '{}' cannot be deleted".format(item_name),
context)
else:
self.fail("Key '{}' of TypedDict {} cannot be deleted".format(
item_name, self.format(typ)), context)

def type_arguments_not_allowed(self, context: Context) -> None:
self.fail('Parameterized generics cannot be used with class or instance checks', context)

Expand Down
149 changes: 149 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, TypeVarType,
AnyType, TypeList, UnboundType, TypeOfAny, TypeType,
)
from mypy import messages
from mypy.messages import MessageBuilder
from mypy.options import Options
import mypy.interpreted_plugin
Expand Down Expand Up @@ -61,6 +62,10 @@ class CheckerPluginInterface:
msg = None # type: MessageBuilder
options = None # type: Options

@abstractmethod
def fail(self, msg: str, ctx: Context) -> None:
raise NotImplementedError

@abstractmethod
def named_generic_type(self, name: str, args: List[Type]) -> Instance:
raise NotImplementedError
Expand Down Expand Up @@ -400,6 +405,14 @@ def get_method_signature_hook(self, fullname: str

if fullname == 'typing.Mapping.get':
return typed_dict_get_signature_callback
elif fullname == 'mypy_extensions._TypedDict.setdefault':
return typed_dict_setdefault_signature_callback
elif fullname == 'mypy_extensions._TypedDict.pop':
return typed_dict_pop_signature_callback
elif fullname == 'mypy_extensions._TypedDict.update':
return typed_dict_update_signature_callback
elif fullname == 'mypy_extensions._TypedDict.__delitem__':
return typed_dict_delitem_signature_callback
elif fullname == 'ctypes.Array.__setitem__':
return ctypes.array_setitem_callback
return None
Expand All @@ -412,6 +425,12 @@ def get_method_hook(self, fullname: str
return typed_dict_get_callback
elif fullname == 'builtins.int.__pow__':
return int_pow_callback
elif fullname == 'mypy_extensions._TypedDict.setdefault':
return typed_dict_setdefault_callback
elif fullname == 'mypy_extensions._TypedDict.pop':
return typed_dict_pop_callback
elif fullname == 'mypy_extensions._TypedDict.__delitem__':
return typed_dict_delitem_callback
elif fullname == 'ctypes.Array.__getitem__':
return ctypes.array_getitem_callback
elif fullname == 'ctypes.Array.__iter__':
Expand Down Expand Up @@ -544,6 +563,136 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
return ctx.default_return_type


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.pop.

This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type('builtins.str', [])
if (isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = TypeVarType(signature.variables[0])
typ = UnionType.make_simplified_union([value_type, tv])
return signature.copy_modified(
arg_types=[str_type, typ],
ret_type=typ)
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_pop_callback(ctx: MethodContext) -> Type:
"""Type check and infer a precise return type for TypedDict.pop."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point we should also start using "constant folding", not in this PR however.

key = ctx.args[0][0].value
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
value_type = ctx.type.items.get(key)
if value_type:
if len(ctx.args[1]) == 0:
return value_type
elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
and len(ctx.args[1]) == 1):
return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]])
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
else:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.setdefault.

This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type('builtins.str', [])
if (isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(ctx.args[1]) == 1):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
return signature.copy_modified(arg_types=[str_type, value_type])
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.setdefault and infer a precise return type."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 2
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
return value_type
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
else:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


def typed_dict_delitem_signature_callback(ctx: MethodSigContext) -> CallableType:
# Replace NoReturn as the argument type.
str_type = ctx.api.named_generic_type('builtins.str', [])
return ctx.default_signature.copy_modified(arg_types=[str_type])


def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.__delitem__."""
if (isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1):
if isinstance(ctx.args[0][0], StrExpr):
key = ctx.args[0][0].value
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
else:
ctx.api.fail(messages.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
return AnyType(TypeOfAny.from_error)
return ctx.default_return_type


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.update."""
signature = ctx.default_signature
if (isinstance(ctx.type, TypedDictType)
and len(signature.arg_types) == 1):
arg_type = signature.arg_types[0]
assert isinstance(arg_type, TypedDictType)
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
return signature.copy_modified(arg_types=[arg_type])
return signature
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a TODO somewhere for .update(x=1, y=2)? Or is this tracked in an issue?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def int_pow_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__pow__."""
if (len(ctx.arg_types) == 1
Expand Down
7 changes: 6 additions & 1 deletion mypy/semanal_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ def create_indirect_imported_name(file_node: MypyFile,
def set_callable_name(sig: Type, fdef: FuncDef) -> Type:
if isinstance(sig, FunctionLike):
if fdef.info:
if fdef.info.fullname() == 'mypy_extensions._TypedDict':
# Avoid exposing the internal _TypedDict name.
class_name = 'TypedDict'
else:
class_name = fdef.info.name()
return sig.with_name(
'{} of {}'.format(fdef.name(), fdef.info.name()))
'{} of {}'.format(fdef.name(), class_name))
else:
return sig.with_name(fdef.name())
else:
Expand Down
6 changes: 2 additions & 4 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,8 @@ def fail_typeddict_arg(self, message: str,
def build_typeddict_typeinfo(self, name: str, items: List[str],
types: List[Type],
required_keys: Set[str]) -> TypeInfo:
fallback = (self.api.named_type_or_none('typing.Mapping',
[self.api.named_type('__builtins__.str'),
self.api.named_type('__builtins__.object')])
or self.api.named_type('__builtins__.object'))
fallback = self.api.named_type_or_none('mypy_extensions._TypedDict', [])
assert fallback is not None
info = self.api.basic_new_typeinfo(name, fallback)
info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys,
fallback)
Expand Down
4 changes: 2 additions & 2 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,8 +790,8 @@ def add_dependency(self, trigger: str, target: Optional[str] = None) -> None:

If the target is not given explicitly, use the current target.
"""
if trigger.startswith(('<builtins.', '<typing.')):
# Don't track dependencies to certain builtins to keep the size of
if trigger.startswith(('<builtins.', '<typing.', '<mypy_extensions.')):
# Don't track dependencies to certain library modules to keep the size of
# the dependencies manageable. These dependencies should only
# change on mypy version updates, which will require a full rebuild
# anyway.
Expand Down
3 changes: 2 additions & 1 deletion mypy/test/testsemanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_semanal(testcase: DataDrivenTestCase) -> None:
'mypy_extensions.pyi',
'typing_extensions.pyi',
'abc.pyi',
'collections.pyi'))
'collections.pyi',
'sys.pyi'))
and not os.path.basename(f.path).startswith('_')
and not os.path.splitext(
os.path.basename(f.path))[0].endswith('_')):
Expand Down
35 changes: 20 additions & 15 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,12 +1172,23 @@ def slice(self, begin: Optional[int], stride: Optional[int],


class TypedDictType(Type):
"""The type of a TypedDict instance. TypedDict(K1=VT1, ..., Kn=VTn)
"""Type of TypedDict object {'k1': v1, ..., 'kn': vn}.

A TypedDictType can be either named or anonymous.
If it is anonymous then its fallback will be an Instance of Mapping[str, V].
If it is named then its fallback will be an Instance of the named type (ex: "Point")
whose TypeInfo has a typeddict_type that is anonymous.
A TypedDict object is a dictionary with specific string (literal) keys. Each
key has a value with a distinct type that depends on the key. TypedDict objects
are normal dict objects at runtime.

A TypedDictType can be either named or anonymous. If it's anonymous, its
fallback will mypy_extensions._TypedDict (Instance). _TypedDict is a subclass
of Mapping[str, object] and defines all non-mapping dict methods that TypedDict
supports. Some dict methods are unsafe and not supported. _TypedDict isn't defined
at runtime.

If a TypedDict is named, its fallback will be an Instance of the named type
(ex: "Point") whose TypeInfo has a typeddict_type that is anonymous. This
is similar to how named tuples work.

TODO: The fallback structure is perhaps overly complicated.
"""

items = None # type: OrderedDict[str, Type] # item_name -> item_type
Expand Down Expand Up @@ -1227,7 +1238,7 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType':
Instance.deserialize(data['fallback']))

def is_anonymous(self) -> bool:
return self.fallback.type.fullname() == 'typing.Mapping'
return self.fallback.type.fullname() == 'mypy_extensions._TypedDict'

def as_anonymous(self) -> 'TypedDictType':
if self.is_anonymous():
Expand All @@ -1250,10 +1261,7 @@ def copy_modified(self, *, fallback: Optional[Instance] = None,

def create_anonymous_fallback(self, *, value_type: Type) -> Instance:
anonymous = self.as_anonymous()
return anonymous.fallback.copy_modified(args=[ # i.e. Mapping
anonymous.fallback.args[0], # i.e. str
value_type
])
return anonymous.fallback

def names_are_wider_than(self, other: 'TypedDictType') -> bool:
return len(other.items.keys() - self.items.keys()) == 0
Expand Down Expand Up @@ -1822,13 +1830,10 @@ def item_str(name: str, typ: str) -> str:
s = '{' + ', '.join(item_str(name, typ.accept(self))
for name, typ in t.items.items()) + '}'
prefix = ''
suffix = ''
if t.fallback and t.fallback.type:
if t.fallback.type.fullname() != 'typing.Mapping':
if t.fallback.type.fullname() != 'mypy_extensions._TypedDict':
prefix = repr(t.fallback.type.fullname()) + ', '
else:
suffix = ', fallback={}'.format(t.fallback.accept(self))
return 'TypedDict({}{}{})'.format(prefix, s, suffix)
return 'TypedDict({}{})'.format(prefix, s)

def visit_raw_literal_type(self, t: RawLiteralType) -> str:
return repr(t.value)
Expand Down
Loading