diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d465ae485b3a..fb23c3590088 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -9,7 +9,7 @@ PartialType, DeletedType, UnboundType, UninhabitedType, TypeType, true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike, get_typ_args, set_typ_args, -) + TypedDictGetFunction) from mypy.nodes import ( NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, @@ -347,6 +347,24 @@ def check_call(self, callee: Type, args: List[Expression], callee.type_object().name(), type.abstract_attributes, context) + if isinstance(callee, TypedDictGetFunction): + if 1 <= len(args) <= 2 and isinstance(args[0], (StrExpr, UnicodeExpr)): + return_type = self.get_typeddict_index_type(callee.typed_dict, args[0]) + arg_types = callee.arg_types + if len(args) == 1: + return_type = UnionType.make_union([ + return_type, NoneTyp()]) + elif isinstance(return_type, TypedDictType) and len(callee.arg_types) == 2: + # Explicitly set the type of the default parameter to + # Union[typing.Mapping, ] in cases where the return value + # is a typed dict. This special case allows for chaining of `get` methods + # when accessing elements deep within nested dictionaries in a safe and + # concise way without having to set up exception handlers. + arg_types = [callee.arg_types[0], + UnionType.make_union([return_type, + self.named_type('typing.Mapping')])] + callee = callee.copy_modified(ret_type=return_type, arg_types=arg_types) + formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, callee.arg_kinds, callee.arg_names, @@ -1484,11 +1502,13 @@ def _get_value(self, index: Expression) -> Optional[int]: return None def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: + return self.get_typeddict_index_type(td_type, index) + + def get_typeddict_index_type(self, td_type: TypedDictType, index: Expression) -> Type: if not isinstance(index, (StrExpr, UnicodeExpr)): self.msg.typeddict_item_name_must_be_string_literal(td_type, index) return AnyType() item_name = index.value - item_type = td_type.items.get(item_name) if item_type is None: self.msg.typeddict_item_name_not_found(td_type, item_name, index) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 31bd699a47e8..dd729fa61dde 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -5,8 +5,8 @@ from mypy.types import ( Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef, Overloaded, TypeVarType, UnionType, PartialType, - DeletedType, NoneTyp, TypeType, function_type -) + DeletedType, NoneTyp, TypeType, function_type, + TypedDictGetFunction) from mypy.nodes import ( TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr, ARG_POS, ARG_STAR, ARG_STAR2, @@ -120,9 +120,12 @@ def analyze_member_access(name: str, original_type=original_type, chk=chk) elif isinstance(typ, TypedDictType): # Actually look up from the fallback instance type. - return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, - is_operator, builtin_type, not_ready_callback, msg, - original_type=original_type, chk=chk) + result = analyze_member_access(name, typ.fallback, node, is_lvalue, is_super, + is_operator, builtin_type, not_ready_callback, msg, + original_type=original_type, chk=chk) + if name == 'get' and isinstance(result, CallableType): + result = TypedDictGetFunction(typ, result) + return result elif isinstance(typ, FunctionLike) and typ.is_type_obj(): # Class attribute. # TODO super? diff --git a/mypy/types.py b/mypy/types.py index 32d7c8340a6c..fcc5cb036175 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -980,6 +980,26 @@ def zipall(self, right: 'TypedDictType') \ yield (item_name, None, right_item_type) +class TypedDictGetFunction(CallableType): + """A special callable type containing a reference to the TypedDict `get` callable instance. + This is needed to delay determining the signature of a TypedDict's `get` method until the + method is actually called. This allows `get` to behave just as indexing into the TypedDict + would. + + This is not a real type, but is needed to allow TypedDict.get to behave as expected. + """ + def __init__(self, typed_dict: TypedDictType, fallback_callable: CallableType) -> None: + super().__init__(fallback_callable.arg_types, fallback_callable.arg_kinds, + fallback_callable.arg_names, fallback_callable.ret_type, + fallback_callable.fallback, fallback_callable.name, + fallback_callable.definition, fallback_callable.variables, + fallback_callable.line, fallback_callable.column, + fallback_callable.is_ellipsis_args, fallback_callable.implicit, + fallback_callable.is_classmethod_class, fallback_callable.special_sig) + self.typed_dict = typed_dict + self.fallback_callable = fallback_callable + + class StarType(Type): """The star type *type_parameter. diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 424c8b2b84e0..95564e669728 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -431,6 +431,90 @@ def set_coordinate(p: TaggedPoint, key: str, value: int) -> None: -- Special Method: get +[case testCanUseGetMethodWithStringLiteralKey] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +reveal_type(p.get('type')) # E: Revealed type is 'Union[builtins.str, builtins.None]' +reveal_type(p.get('x')) # E: Revealed type is 'Union[builtins.int, builtins.None]' +reveal_type(p.get('y', 0)) # E: Revealed type is 'builtins.int' +[builtins fixtures/dict.pyi] + +[case testDefaultParameterStillTypeChecked] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str") +[builtins fixtures/dict.pyi] + +[case testCannotGetMethodWithInvalidStringLiteralKey] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +p.get('z') # E: 'z' is not a valid item name; expected one of ['type', 'x', 'y'] +[builtins fixtures/dict.pyi] + +[case testGetMethodWithVariableKeyFallsBack] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +key = 'type' +reveal_type(p.get(key)) # E: Revealed type is 'builtins.object*' +[builtins fixtures/dict.pyi] + +[case testChainedGetMethodWithDictFallback] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +PointSet = TypedDict('PointSet', {'first_point': TaggedPoint}) +p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337)) +reveal_type(p.get('first_point', {}).get('x', 0)) # E: Revealed type is 'builtins.int' +[builtins fixtures/dict.pyi] + +[case testGetMethodInvalidDefaultType] +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +PointSet = TypedDict('PointSet', {'first_point': TaggedPoint}) +p = PointSet(first_point=TaggedPoint(type='2d', x=42, y=1337)) +p.get('first_point', 32) # E: Argument 2 to "get" of "Mapping" has incompatible type "int"; expected "Union[TypedDict(type=str, x=int, y=int), Mapping]" +[builtins fixtures/dict.pyi] + +[case testGetMethodOnList] +from typing import List +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +PointSet = TypedDict('PointSet', {'points': List[TaggedPoint]}) +p = PointSet(points=[TaggedPoint(type='2d', x=42, y=1337)]) +reveal_type(p.get('points', [])) # E: Revealed type is 'builtins.list[TypedDict(type=builtins.str, x=builtins.int, y=builtins.int, _fallback=__main__.TaggedPoint)]' +[builtins fixtures/dict.pyi] + +[case testGetMethodWithListOfStrUnifies] +from typing import List +from mypy_extensions import TypedDict +Items = TypedDict('Items', {'name': str, 'values': List[str]}) +def foo(i: Items) -> None: + reveal_type(i.get('values', [])) # E: Revealed type is 'builtins.list[builtins.str]' +[builtins fixtures/dict.pyi] + +[case testDictGetMethodStillCallable] +from typing import Callable +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': int, 'y': int}) +p = Point(x=42, y=13) +def invoke_method(method: Callable[[str, int], int]) -> None: + pass +invoke_method(p.get) +[builtins fixtures/dict.pyi] + +[case testDictGetMethodStillCallableWithObject] +from typing import Callable +from mypy_extensions import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'type': str, 'x': int, 'y': int}) +p = TaggedPoint(type='2d', x=42, y=1337) +def invoke_method(method: Callable[..., object]) -> None: + pass +invoke_method(p.get) +[builtins fixtures/dict.pyi] + -- TODO: Implement support for these cases: --[case testGetOfTypedDictWithValidStringLiteralKeyReturnsPreciseType] --[case testGetOfTypedDictWithInvalidStringLiteralKeyIsError] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index 5a7886439692..877e4ef8c230 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -18,6 +18,7 @@ class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]): def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass def __setitem__(self, k: KT, v: VT) -> None: pass def __iter__(self) -> Iterator[KT]: pass + def get(self, k: KT, default: VT=None) -> VT: pass def update(self, a: Mapping[KT, VT]) -> None: pass class int: # for convenience diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index 77a7b349e4cd..b7f59ec07247 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -78,7 +78,9 @@ class Sequence(Iterable[T], Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass -class Mapping(Generic[T, U]): pass +class Mapping(Generic[T, U]): + @abstractmethod + def get(self, k: T, default: U=None) -> U: pass class MutableMapping(Generic[T, U]): pass