From cc9615ac34ea3b9c22cad06f64d49e3b36eaf729 Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Fri, 28 Jan 2022 13:32:24 -0500 Subject: [PATCH 1/6] Disallow direct item access of NotRequired TypedDict properties: these should always be accessed through .get() because the keys may not be present. Fixes #12094 --- mypy/checkexpr.py | 11 ++++++-- mypy/checkmember.py | 2 +- mypy/checkpattern.py | 2 +- mypy/errorcodes.py | 3 +++ mypy/messages.py | 11 ++++++++ mypy/types.py | 3 +++ test-data/unit/check-typeddict.test | 40 ++++++++++++++++++++++++++--- 7 files changed, 64 insertions(+), 8 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9bf3ec3a4456..919c8cb5b969 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2932,6 +2932,9 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type: def visit_index_expr(self, e: IndexExpr) -> Type: """Type check an index expression (base[index]). + This function is only used for *expressions* (rvalues) not for setitem + statements (lvalues). + It may also represent type application. """ result = self.visit_index_expr_helper(e) @@ -2988,7 +2991,7 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, else: return self.nonliteral_tuple_index_helper(left_type, index) elif isinstance(left_type, TypedDictType): - return self.visit_typeddict_index_expr(left_type, e.index) + return self.visit_typeddict_index_expr(left_type, e.index, is_expression=True) elif (isinstance(left_type, CallableType) and left_type.is_type_obj() and left_type.type_object().is_enum): return self.visit_enum_index_expr(left_type.type_object(), e.index, e) @@ -3081,7 +3084,9 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression, - local_errors: Optional[MessageBuilder] = None + local_errors: Optional[MessageBuilder] = None, + *, + is_expression: bool ) -> Type: local_errors = local_errors or self.msg if isinstance(index, (StrExpr, UnicodeExpr)): @@ -3113,6 +3118,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, local_errors.typeddict_key_not_found(td_type, key_name, index) return AnyType(TypeOfAny.from_error) else: + if is_expression and not td_type.is_required(key_name): + local_errors.typeddict_key_not_required(td_type, key_name, index) value_types.append(value_type) return make_simplified_union(value_types) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 1c66320bb562..fc1d790c7758 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -856,7 +856,7 @@ def analyze_typeddict_access(name: str, typ: TypedDictType, # Since we can get this during `a['key'] = ...` # it is safe to assume that the context is `IndexExpr`. item_type = mx.chk.expr_checker.visit_typeddict_index_expr( - typ, mx.context.index) + typ, mx.context.index, is_expression=False) else: # It can also be `a.__setitem__(...)` direct call. # In this case `item_type` can be `Any`, diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 2c40e856be88..327e43cf45ed 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -412,7 +412,7 @@ def get_mapping_item_type(self, mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( - mapping_type, key, local_errors=local_errors) + mapping_type, key, local_errors=local_errors, is_expression=False) # If we can't determine the type statically fall back to treating it as a normal # mapping if local_errors.is_errors(): diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index ba716608ae56..e7c28dd3b84a 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -69,6 +69,9 @@ def __str__(self) -> str: TYPEDDICT_ITEM: Final = ErrorCode( "typeddict-item", "Check items when constructing TypedDict", "General" ) +TYPEDDICT_ITEM_ACCESS: Final = ErrorCode( + "typeddict-item-access", "Check item access when using TypedDict", "General" +) HAS_TYPE: Final = ErrorCode( "has-type", "Check that type of reference can be determined", "General" ) diff --git a/mypy/messages.py b/mypy/messages.py index 406237783cf1..12477c3a0cb2 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1276,6 +1276,17 @@ def typeddict_key_not_found( self.note("Did you mean {}?".format( pretty_seq(matches[:3], "or")), context, code=codes.TYPEDDICT_ITEM) + def typeddict_key_not_required( + self, + typ: TypedDictType, + item_name: str, + context: Context) -> None: + type_name: str = "" + if not typ.is_anonymous(): + type_name = format_type(typ) + " " + self.fail('TypedDict {}key "{}" is not required.'.format( + type_name, item_name), context, code=codes.TYPEDDICT_ITEM_ACCESS) + def typeddict_context_ambiguous( self, types: List[TypedDictType], diff --git a/mypy/types.py b/mypy/types.py index 1d7ab669a2d4..8d09179eb910 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1680,6 +1680,9 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str], def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_typeddict_type(self) + def is_required(self, key: str) -> bool: + return key in self.required_keys + def __hash__(self) -> int: return hash((frozenset(self.items.items()), self.fallback, frozenset(self.required_keys))) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index a9321826b3ba..ed71b5902c9d 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -751,7 +751,6 @@ def get_coordinate(p: TaggedPoint, key: str) -> Union[str, int]: return p[key] # E: TypedDict key must be a string literal; expected one of ("type", "x", "y") [builtins fixtures/dict.pyi] - -- Special Method: __setitem__ [case testCanSetItemOfTypedDictWithValidStringLiteralKeyAndCompatibleValueType] @@ -1048,7 +1047,8 @@ reveal_type(d.get('x', {})) \ reveal_type(d.get('x', None)) \ # N: Revealed type is "Union[TypedDict('__main__.C', {'a': builtins.int}), None]" reveal_type(d.get('x', {}).get('a')) # N: Revealed type is "Union[builtins.int, None]" -reveal_type(d.get('x', {})['a']) # N: Revealed type is "builtins.int" +reveal_type(d.get('x', {})['a']) # N: Revealed type is "builtins.int" \ + # E: TypedDict "C" key "a" is not required. [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] @@ -1100,8 +1100,10 @@ f(D(x='')) # E: Incompatible types (expression has type "str", TypedDict item "x from mypy_extensions import TypedDict D = TypedDict('D', {'x': int, 'y': str}, total=False) d: D -reveal_type(d['x']) # N: Revealed type is "builtins.int" -reveal_type(d['y']) # N: Revealed type is "builtins.str" +reveal_type(d['x']) # N: Revealed type is "builtins.int" \ + # E: TypedDict "D" key "x" is not required. +reveal_type(d['y']) # N: Revealed type is "builtins.str" \ + # E: TypedDict "D" key "y" is not required. reveal_type(d.get('x')) # N: Revealed type is "builtins.int" reveal_type(d.get('y')) # N: Revealed type is "builtins.str" [builtins fixtures/dict.pyi] @@ -2284,6 +2286,36 @@ from typing import NotRequired Foo = TypedDict("Foo", {"a.x": NotRequired[int]}) [typing fixtures/typing-typeddict.pyi] +[case testCannotGetItemNotRequired] +from typing import TypedDict +from typing import NotRequired +TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': NotRequired[int]}) +p: TaggedPoint +p['y'] # E: TypedDict "TaggedPoint" key "y" is not required. +[typing fixtures/typing-typeddict.pyi] + +[case testCannotGetItemNotTotal] +from typing import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': int}, total=False) +p: TaggedPoint +p['y'] # E: TypedDict "TaggedPoint" key "y" is not required. +[typing fixtures/typing-typeddict.pyi] + +[case testCanSetItemNotRequired] +from typing import TypedDict +from typing import NotRequired +TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': NotRequired[int]}) +p: TaggedPoint +p['y'] = 1 +[typing fixtures/typing-typeddict.pyi] + +[case testCanSetItemNotTotal] +from typing import TypedDict +TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': int}, total=False) +p: TaggedPoint +p['y'] = 1 +[typing fixtures/typing-typeddict.pyi] + -- Union dunders [case testTypedDictUnionGetItem] From 37f9e7b687532634bc02677234422139a2ec42e8 Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Sat, 29 Jan 2022 16:30:34 -0500 Subject: [PATCH 2/6] Add stronger type-checking to TypedDict.get such that: 1. invalid keys type-check as the default value type (None or the provided default) 2. required keys type-check just to the declared type, never the fallback 3. optional keys type-check to the declared type or the default value type Fixes tests to match the better typecheck results. --- mypy/plugins/default.py | 88 +++++++++++++++++++---------- mypyc/test-data/run-misc.test | 8 ++- test-data/unit/check-literal.test | 26 +++++---- test-data/unit/check-narrowing.test | 9 ++- test-data/unit/check-typeddict.test | 66 +++++++++++++++++++--- test-data/unit/pythoneval.test | 2 +- 6 files changed, 143 insertions(+), 56 deletions(-) diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index c57c5f9a18d9..7ad6932c2ea8 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -238,37 +238,65 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: def typed_dict_get_callback(ctx: MethodContext) -> Type: """Infer a precise return type for TypedDict.get with literal first argument.""" - if (isinstance(ctx.type, TypedDictType) - and len(ctx.arg_types) >= 1 - and len(ctx.arg_types[0]) == 1): - keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) - if keys is None: - return ctx.default_return_type + if not ( + isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) >= 1 + and len(ctx.arg_types[0]) == 1 + ): + return ctx.default_return_type - output_types: List[Type] = [] - for key in keys: - value_type = get_proper_type(ctx.type.items.get(key)) - if value_type is None: - return ctx.default_return_type - - if len(ctx.arg_types) == 1: - output_types.append(value_type) - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): - default_arg = ctx.args[1][0] - if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 - and isinstance(value_type, TypedDictType)): - # Special case '{}' as the default for a typed dict type. - output_types.append(value_type.copy_modified(required_keys=set())) - else: - output_types.append(value_type) - output_types.append(ctx.arg_types[1][0]) - - if len(ctx.arg_types) == 1: - output_types.append(NoneType()) - - return make_simplified_union(output_types) - return ctx.default_return_type + keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) + if keys is None: + return ctx.default_return_type + + default_type: Optional[Type] + if len(ctx.arg_types) == 1: + default_type = None + elif len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1: + default_type = ctx.arg_types[1][0] + else: + default_type = ctx.default_return_type + + output_types: List[Type] = [] + + for key in keys: + value_type = get_proper_type(ctx.type.items.get(key)) + if value_type is None: + # It would be nice to issue a "TypedDict has no key {key}" failure here. However, + # we don't do this because in the case where you have a union of typeddicts, and + # one of them has the key but others don't, an error message is incorrect, and + # the plugin API has no mechanism to distinguish these cases. + output_types.append(default_type or NoneType()) + continue + + if ctx.type.is_required(key): + # Without unions we could issue an error for .get('required_key', default) because + # the default doesn't make sense. But because of unions, we don't do that. + output_types.append(value_type) + continue + + if default_type is None: + output_types.extend([ + value_type, + NoneType(), + ]) + continue + + # Special case '{}' as the default for a typed dict type. + if len(ctx.args[1]) == 1: + default_arg = ctx.args[1][0] + if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 + and isinstance(value_type, TypedDictType)): + + output_types.append(value_type.copy_modified(required_keys=set())) + continue + + output_types.extend([ + value_type, + default_type, + ]) + + return make_simplified_union(output_types) def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType: diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 736169f95b82..0a4a5ca4b26a 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -679,6 +679,7 @@ TypeError 10 [case testClassBasedTypedDict] +[typing fixtures/typing-full.pyi] from typing_extensions import TypedDict class TD(TypedDict): @@ -709,8 +710,11 @@ def test_inherited_typed_dict() -> None: def test_non_total_typed_dict() -> None: d3 = TD3(c=3) d4 = TD4(a=1, b=2, c=3, d=4) - assert d3['c'] == 3 - assert d4['d'] == 4 + assert d3['c'] == 3 # type: ignore[typeddict-item-access] + assert d4['d'] == 4 # type: ignore[typeddict-item-access] + assert d3.get('c') == 3 + assert d3.get('d') == 4 + assert d3.get('z') is None [case testClassBasedNamedTuple] from typing import NamedTuple diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 37ae12419151..fbb1e4ff929a 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2222,12 +2222,14 @@ c_key: Literal["c"] d: Outer reveal_type(d[a_key]) # N: Revealed type is "builtins.int" -reveal_type(d[b_key]) # N: Revealed type is "builtins.str" +reveal_type(d[b_key]) # N: Revealed type is "builtins.str" \ + # E: TypedDict "Outer" key "b" is not required. +reveal_type(d.get(b_key)) # N: Revealed type is "builtins.str" d[c_key] # E: TypedDict "Outer" has no key "c" -reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]" +reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int" reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]" -reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object" +reveal_type(d.get(c_key, u)) # N: Revealed type is "__main__.Unrelated" reveal_type(d.pop(a_key)) # E: Key "a" of TypedDict "Outer" cannot be deleted \ # N: Revealed type is "builtins.int" @@ -2270,8 +2272,8 @@ u: Unrelated reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int" reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int" reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int" -reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]" -reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object" +reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int" +reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "__main__.Unrelated" a[int_key_bad] # E: Tuple index out of range b[int_key_bad] # E: Tuple index out of range @@ -2311,6 +2313,7 @@ tup2[idx_bad] # E: Tuple index out of range [out] [case testLiteralIntelligentIndexingTypedDictUnions] +# flags: --strict-optional from typing_extensions import Literal, Final from mypy_extensions import TypedDict @@ -2338,12 +2341,12 @@ bad_keys: Literal["a", "bad"] reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]" reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]" -reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]" +reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]" reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]" reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]" reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]" -reveal_type(test.get(bad_keys)) # N: Revealed type is "builtins.object*" -reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "builtins.object" +reveal_type(test.get(bad_keys)) # N: Revealed type is "Union[__main__.A, None]" +reveal_type(test.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?]" del test[optional_keys] @@ -2411,6 +2414,7 @@ UnicodeDict = TypedDict(b'UnicodeDict', {'key': int}) [typing fixtures/typing-medium.pyi] [case testLiteralIntelligentIndexingMultiTypedDict] +# flags: --strict-optional from typing import Union from typing_extensions import Literal from mypy_extensions import TypedDict @@ -2439,9 +2443,9 @@ x[bad_keys] # E: TypedDict "D1" has no key "d" \ reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]" reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]" -reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]" -reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object*" -reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object" +reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]" +reveal_type(x.get(bad_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, None, __main__.D]" +reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B, __main__.C, Literal[3]?, __main__.D]" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 16cdc69ec1b7..34e54c98e62c 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -310,17 +310,20 @@ class TypedDict2(TypedDict, total=False): key: Literal['B', 'C'] x: Union[TypedDict1, TypedDict2] -if x['key'] == 'A': + +# NOTE: we ignore typeddict-item-access errors here because the narrowing doesn't work with .get(). + +if x['key'] == 'A': # type: ignore[typeddict-item-access] reveal_type(x) # N: Revealed type is "TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]})" else: reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]" -if x['key'] == 'C': +if x['key'] == 'C': # type: ignore[typeddict-item-access] reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]" else: reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]" -if x['key'] == 'D': +if x['key'] == 'D': # type: ignore[typeddict-item-access] reveal_type(x) # E: Statement is unreachable else: reveal_type(x) # N: Revealed type is "Union[TypedDict('__main__.TypedDict1', {'key'?: Union[Literal['A'], Literal['C']]}), TypedDict('__main__.TypedDict2', {'key'?: Union[Literal['B'], Literal['C']]})]" diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index ed71b5902c9d..ae2898edb20f 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -966,15 +966,17 @@ if int(): [case testTypedDictGetMethod] # flags: --strict-optional -from mypy_extensions import TypedDict +from typing import TypedDict, NotRequired class A: pass -D = TypedDict('D', {'x': int, 'y': str}) +D = TypedDict('D', {'x': int, 'y': NotRequired[str]}) d: D -reveal_type(d.get('x')) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(d.get('x')) # N: Revealed type is "builtins.int" reveal_type(d.get('y')) # N: Revealed type is "Union[builtins.str, None]" -reveal_type(d.get('x', A())) # N: Revealed type is "Union[builtins.int, __main__.A]" +reveal_type(d.get('x', A())) # N: Revealed type is "builtins.int" reveal_type(d.get('x', 1)) # N: Revealed type is "builtins.int" reveal_type(d.get('y', None)) # N: Revealed type is "Union[builtins.str, None]" +reveal_type(d.get('y', 24)) # N: Revealed type is "Union[builtins.str, Literal[24]?]" +reveal_type(d.get('y', A())) # N: Revealed type is "Union[builtins.str, __main__.A]" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] @@ -988,7 +990,7 @@ d: D reveal_type(d.get('x', [])) # N: Revealed type is "builtins.list[builtins.int]" d.get('x', ['x']) # E: List item 0 has incompatible type "str"; expected "int" a = [''] -reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str*]]" +reveal_type(d.get('x', a)) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] @@ -1004,14 +1006,59 @@ d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument # N: Possible overload variants: \ # N: def get(self, k: str) -> object \ # N: def [V] get(self, k: str, default: Union[int, V]) -> object -x = d.get('z') -reveal_type(x) # N: Revealed type is "builtins.object*" s = '' y = d.get(s) reveal_type(y) # N: Revealed type is "builtins.object*" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testTypedDictGetRequiredKey] +from typing import TypedDict, NotRequired +D = TypedDict('D', {'x': int, 'y': NotRequired[int]}) +d: D +x = d.get('x') +reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictGetNotRequiredKey] +# flags: --strict-optional +from typing import TypedDict, NotRequired +D = TypedDict('D', {'x': int, 'y': NotRequired[str]}) +d: D +y = d.get('y') +reveal_type(y) # N: Revealed type is "Union[builtins.str, None]" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnionGet] +# flags: --strict-optional +from typing import TypedDict, NotRequired, Union +A = TypedDict('A', {'m': int, 'n': NotRequired[str], 'p': int}) +B = TypedDict('B', {'m': int, 'o': str, 'p': str}) +v: Union[A, B] +m = v.get('m') +reveal_type(m) # N: Revealed type is "builtins.int" +n = v.get('n') +reveal_type(n) # N: Revealed type is "Union[builtins.str, None]" +o = v.get('o') +reveal_type(o) # N: Revealed type is "Union[None, builtins.str]" +p = v.get('p') +reveal_type(p) # N: Revealed type is "Union[builtins.int, builtins.str]" +z = v.get('z') +reveal_type(z) # N: Revealed type is "None" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictGetMissingKey] +from typing import TypedDict, NotRequired +D = TypedDict('D', {'x': int, 'y': NotRequired[int]}) +d: D +z = d.get('z') +reveal_type(z) # N: Revealed type is "None" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testTypedDictMissingMethod] from mypy_extensions import TypedDict D = TypedDict('D', {'x': int, 'y': str}) @@ -1040,7 +1087,7 @@ p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str") # flags: --strict-optional from mypy_extensions import TypedDict C = TypedDict('C', {'a': int}) -D = TypedDict('D', {'x': C, 'y': str}) +D = TypedDict('D', {'x': C, 'y': str}, total=False) d: D reveal_type(d.get('x', {})) \ # N: Revealed type is "TypedDict('__main__.C', {'a'?: builtins.int})" @@ -1711,6 +1758,7 @@ alias(s) [builtins fixtures/dict.pyi] [case testPluginUnionsOfTypedDicts] +# flags: --strict-optional from typing import Union from mypy_extensions import TypedDict @@ -1727,7 +1775,7 @@ td: Union[TDA, TDB] reveal_type(td.get('a')) # N: Revealed type is "builtins.int" reveal_type(td.get('b')) # N: Revealed type is "Union[builtins.str, builtins.int]" -reveal_type(td.get('c')) # N: Revealed type is "builtins.object*" +reveal_type(td.get('c')) # N: Revealed type is "Union[None, builtins.int]" reveal_type(td['a']) # N: Revealed type is "builtins.int" reveal_type(td['b']) # N: Revealed type is "Union[builtins.str, builtins.int]" diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 993af4ced61e..76c452828369 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1081,7 +1081,7 @@ reveal_type(d.get(s)) [out] _testTypedDictGet.py:7: note: Revealed type is "builtins.int" _testTypedDictGet.py:8: note: Revealed type is "builtins.str" -_testTypedDictGet.py:9: note: Revealed type is "builtins.object*" +_testTypedDictGet.py:9: note: Revealed type is "None" _testTypedDictGet.py:10: error: All overload variants of "get" of "Mapping" require at least one argument _testTypedDictGet.py:10: note: Possible overload variants: _testTypedDictGet.py:10: note: def get(self, key: str) -> object From 04b260ac132d7a377d55ad3fe2d6822d870f5208 Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Tue, 1 Feb 2022 15:02:19 -0500 Subject: [PATCH 3/6] Narrow not-required fields into required fields when code does `'literal' in typed_dict`. --- mypy/checker.py | 93 +++++++++++++++++++++-------- test-data/unit/check-narrowing.test | 15 +++++ 2 files changed, 84 insertions(+), 24 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 2f99b9b4fece..62a84dc0c2dc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4500,7 +4500,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # types of literal string or enum expressions). operands = [collapse_walrus(x) for x in node.operands] - operand_types = [] + operand_types: List[Type] = [] narrowable_operand_index_to_hash = {} for i, expr in enumerate(operands): if expr not in type_map: @@ -4543,6 +4543,9 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM partial_type_maps = [] for operator, expr_indices in simplified_operator_list: + if_map: TypeMap = {} + else_map: TypeMap = {} + if operator in {'is', 'is not', '==', '!='}: # is_valid_target: # Controls which types we're allowed to narrow exprs to. Note that @@ -4578,8 +4581,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: expr_types = [operand_types[i] for i in expr_indices] should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) - if_map: TypeMap = {} - else_map: TypeMap = {} if should_narrow_by_identity: if_map, else_map = self.refine_identity_comparison_expression( operands, @@ -4609,34 +4610,28 @@ def has_no_custom_eq_checks(t: Type) -> bool: elif operator in {'in', 'not in'}: assert len(expr_indices) == 2 left_index, right_index = expr_indices - if left_index not in narrowable_operand_index_to_hash: - continue - item_type = operand_types[left_index] - collection_type = operand_types[right_index] + left_is_narrowable = left_index in narrowable_operand_index_to_hash + right_is_narrowable = right_index in narrowable_operand_index_to_hash - # We only try and narrow away 'None' for now - if not is_optional(item_type): - continue + left_type = operand_types[left_index] + right_type = operand_types[right_index] - collection_item_type = get_proper_type(builtin_item_type(collection_type)) - if collection_item_type is None or is_optional(collection_item_type): - continue - if (isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == 'builtins.object'): - continue - if is_overlapping_erased_types(item_type, collection_item_type): - if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} - else: - continue - else: - if_map = {} - else_map = {} + if left_is_narrowable: + narrowed_left_type = self.refine_optional_in(left_type, right_type) + if narrowed_left_type: + if_map = {operands[left_index]: narrowed_left_type} + + elif right_is_narrowable: + narrowed_right_type = self.refine_typeddict_in(left_type, right_type) + if narrowed_right_type: + if_map = {operands[right_index]: narrowed_right_type} if operator in {'is not', '!=', 'not in'}: if_map, else_map = else_map, if_map - partial_type_maps.append((if_map, else_map)) + if if_map != {} or else_map != {}: + partial_type_maps.append((if_map, else_map)) return reduce_conditional_maps(partial_type_maps) elif isinstance(node, AssignmentExpr): @@ -4865,6 +4860,56 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: expr = parent_expr expr_type = output[parent_expr] = make_simplified_union(new_parent_types) + def refine_optional_in(self, + item_type: Type, + collection_type: Type, + ) -> Optional[Type]: + """ + Check whether a condition `optional_item in collection_type` can narrow away Optional. + + Returns the narrowed item_type, if any narrowing is appropriate. + """ + if not is_optional(item_type): + return None + + collection_item_type = get_proper_type(builtin_item_type(collection_type)) + if collection_item_type is None or is_optional(collection_item_type): + return None + + if (isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == 'builtins.object'): + return None + if is_overlapping_erased_types(item_type, collection_item_type): + return remove_optional(item_type) + return None + + def refine_typeddict_in(self, + literal_type: Type, + collection_type: Type, + ) -> Optional[Type]: + """ + Check whether a condition `'literal' in typeddict` can narrow a non-required ite + into a required item. + + Returns the narrowed collection_type, if any narrowing is appropriate. + """ + collection_type = get_proper_type(collection_type) + if not isinstance(collection_type, TypedDictType): + return None + + literals = try_getting_str_literals_from_type(literal_type) + if literals is None or len(literals) > 1: + return None + + key = literals[0] + if key not in collection_type.items: + return None + + if collection_type.is_required(key): + return None + + return collection_type.copy_modified(required_keys=collection_type.required_keys | {key}) + def refine_identity_comparison_expression(self, operands: List[Expression], operand_types: List[Type], diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 34e54c98e62c..847378404e9d 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1151,6 +1151,21 @@ def f(d: Union[Foo, Bar]) -> None: reveal_type(d) # N: Revealed type is "TypedDict('__main__.Foo', {'tag': Literal[__main__.E.FOO], 'x': builtins.int})" [builtins fixtures/dict.pyi] +[case testNarrowingTypedDictNotRequired] +# flags: --strict-optional +from typing_extensions import TypedDict + +class Foo(TypedDict, total=False): + a: str + +foo: Foo +if 'a' in foo: + reveal_type(foo['a']) # N: Revealed type is "builtins.str" +else: + reveal_type(foo['a']) # N: Revealed type is "builtins.str" \ + # E: TypedDict "Foo" key "a" is not required. +[builtins fixtures/dict.pyi] + [case testNarrowingUsingMetaclass] # flags: --strict-optional from typing import Type From 2c38fd3ea877ed7e940d63239af0f80d112f8735 Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Tue, 1 Feb 2022 15:23:23 -0500 Subject: [PATCH 4/6] Add documentation for the stricter item access on TypedDicts. Also document the new and more-ergonomic way to mix required and not-required items in a TypedDict using the NotRequired annotation. --- docs/source/more_types.rst | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/docs/source/more_types.rst b/docs/source/more_types.rst index 82a6568afcb2..96de88ec0de5 100644 --- a/docs/source/more_types.rst +++ b/docs/source/more_types.rst @@ -1029,11 +1029,9 @@ Sometimes you want to allow keys to be left out when creating a options['language'] = 'en' You may need to use :py:meth:`~dict.get` to access items of a partial (non-total) -``TypedDict``, since indexing using ``[]`` could fail at runtime. -However, mypy still lets use ``[]`` with a partial ``TypedDict`` -- you -just need to be careful with it, as it could result in a :py:exc:`KeyError`. -Requiring :py:meth:`~dict.get` everywhere would be too cumbersome. (Note that you -are free to use :py:meth:`~dict.get` with total ``TypedDict``\s as well.) +``TypedDict``, since indexing using ``[]`` could fail at runtime. By default +mypy will issue an error for this case; it is possible to disable this check +by adding "typeddict-item-access" to the :confval:`disable_error_code` config option. Keys that aren't required are shown with a ``?`` in error messages: @@ -1120,18 +1118,15 @@ Now ``BookBasedMovie`` has keys ``name``, ``year`` and ``based_on``. Mixing required and non-required items -------------------------------------- -In addition to allowing reuse across ``TypedDict`` types, inheritance also allows -you to mix required and non-required (using ``total=False``) items -in a single ``TypedDict``. Example: +When a ``TypedDict`` has a mix of items that are required and not required, +the ``NotRequired`` type annotation can be used to specify this for each field: .. code-block:: python - class MovieBase(TypedDict): + class Movie(TypedDict): name: str year: int - - class Movie(MovieBase, total=False): - based_on: str + based_on: NotRequired[str] Now ``Movie`` has required keys ``name`` and ``year``, while ``based_on`` can be left out when constructing an object. A ``TypedDict`` with a mix of required From e53e1b7e87973382e79e1cbb5df9b9328aba6e4e Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Wed, 2 Feb 2022 09:07:04 -0500 Subject: [PATCH 5/6] Rewording of error message per review by davidfstr. --- mypy/messages.py | 2 +- test-data/unit/check-literal.test | 2 +- test-data/unit/check-narrowing.test | 2 +- test-data/unit/check-typeddict.test | 10 +++++----- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mypy/messages.py b/mypy/messages.py index 12477c3a0cb2..898edef1a59b 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1284,7 +1284,7 @@ def typeddict_key_not_required( type_name: str = "" if not typ.is_anonymous(): type_name = format_type(typ) + " " - self.fail('TypedDict {}key "{}" is not required.'.format( + self.fail('TypedDict {}key "{}" is not required and might not be present.'.format( type_name, item_name), context, code=codes.TYPEDDICT_ITEM_ACCESS) def typeddict_context_ambiguous( diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index fbb1e4ff929a..008b746c259c 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2223,7 +2223,7 @@ d: Outer reveal_type(d[a_key]) # N: Revealed type is "builtins.int" reveal_type(d[b_key]) # N: Revealed type is "builtins.str" \ - # E: TypedDict "Outer" key "b" is not required. + # E: TypedDict "Outer" key "b" is not required and might not be present. reveal_type(d.get(b_key)) # N: Revealed type is "builtins.str" d[c_key] # E: TypedDict "Outer" has no key "c" diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 847378404e9d..71feef3237a1 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1163,7 +1163,7 @@ if 'a' in foo: reveal_type(foo['a']) # N: Revealed type is "builtins.str" else: reveal_type(foo['a']) # N: Revealed type is "builtins.str" \ - # E: TypedDict "Foo" key "a" is not required. + # E: TypedDict "Foo" key "a" is not required and might not be present. [builtins fixtures/dict.pyi] [case testNarrowingUsingMetaclass] diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index ae2898edb20f..5d141d9de17c 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1095,7 +1095,7 @@ reveal_type(d.get('x', None)) \ # N: Revealed type is "Union[TypedDict('__main__.C', {'a': builtins.int}), None]" reveal_type(d.get('x', {}).get('a')) # N: Revealed type is "Union[builtins.int, None]" reveal_type(d.get('x', {})['a']) # N: Revealed type is "builtins.int" \ - # E: TypedDict "C" key "a" is not required. + # E: TypedDict "C" key "a" is not required and might not be present. [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] @@ -1148,9 +1148,9 @@ from mypy_extensions import TypedDict D = TypedDict('D', {'x': int, 'y': str}, total=False) d: D reveal_type(d['x']) # N: Revealed type is "builtins.int" \ - # E: TypedDict "D" key "x" is not required. + # E: TypedDict "D" key "x" is not required and might not be present. reveal_type(d['y']) # N: Revealed type is "builtins.str" \ - # E: TypedDict "D" key "y" is not required. + # E: TypedDict "D" key "y" is not required and might not be present. reveal_type(d.get('x')) # N: Revealed type is "builtins.int" reveal_type(d.get('y')) # N: Revealed type is "builtins.str" [builtins fixtures/dict.pyi] @@ -2339,14 +2339,14 @@ from typing import TypedDict from typing import NotRequired TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': NotRequired[int]}) p: TaggedPoint -p['y'] # E: TypedDict "TaggedPoint" key "y" is not required. +p['y'] # E: TypedDict "TaggedPoint" key "y" is not required and might not be present. [typing fixtures/typing-typeddict.pyi] [case testCannotGetItemNotTotal] from typing import TypedDict TaggedPoint = TypedDict('TaggedPoint', {'x': int, 'y': int}, total=False) p: TaggedPoint -p['y'] # E: TypedDict "TaggedPoint" key "y" is not required. +p['y'] # E: TypedDict "TaggedPoint" key "y" is not required and might not be present. [typing fixtures/typing-typeddict.pyi] [case testCanSetItemNotRequired] From 37aabde925d16336be06d04c7892fa35fd01f29a Mon Sep 17 00:00:00 2001 From: Benjamin Smedberg Date: Wed, 2 Feb 2022 09:09:41 -0500 Subject: [PATCH 6/6] Rename parameter per review suggest from davidfstr --- mypy/checkexpr.py | 6 +++--- mypy/checkmember.py | 2 +- mypy/checkpattern.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 919c8cb5b969..a99bdb3f539a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2991,7 +2991,7 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, else: return self.nonliteral_tuple_index_helper(left_type, index) elif isinstance(left_type, TypedDictType): - return self.visit_typeddict_index_expr(left_type, e.index, is_expression=True) + return self.visit_typeddict_index_expr(left_type, e.index, is_rvalue=True) elif (isinstance(left_type, CallableType) and left_type.is_type_obj() and left_type.type_object().is_enum): return self.visit_enum_index_expr(left_type.type_object(), e.index, e) @@ -3086,7 +3086,7 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression, local_errors: Optional[MessageBuilder] = None, *, - is_expression: bool + is_rvalue: bool ) -> Type: local_errors = local_errors or self.msg if isinstance(index, (StrExpr, UnicodeExpr)): @@ -3118,7 +3118,7 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, local_errors.typeddict_key_not_found(td_type, key_name, index) return AnyType(TypeOfAny.from_error) else: - if is_expression and not td_type.is_required(key_name): + if is_rvalue and not td_type.is_required(key_name): local_errors.typeddict_key_not_required(td_type, key_name, index) value_types.append(value_type) return make_simplified_union(value_types) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index fc1d790c7758..ed9103c3972d 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -856,7 +856,7 @@ def analyze_typeddict_access(name: str, typ: TypedDictType, # Since we can get this during `a['key'] = ...` # it is safe to assume that the context is `IndexExpr`. item_type = mx.chk.expr_checker.visit_typeddict_index_expr( - typ, mx.context.index, is_expression=False) + typ, mx.context.index, is_rvalue=False) else: # It can also be `a.__setitem__(...)` direct call. # In this case `item_type` can be `Any`, diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 327e43cf45ed..ec10e913cb3c 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -412,7 +412,7 @@ def get_mapping_item_type(self, mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( - mapping_type, key, local_errors=local_errors, is_expression=False) + mapping_type, key, local_errors=local_errors, is_rvalue=False) # If we can't determine the type statically fall back to treating it as a normal # mapping if local_errors.is_errors():