From 82f330a5f0d47af54fafe7a12b624f65fd9b3182 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 11 Jun 2023 21:50:12 +0100 Subject: [PATCH 01/11] Start working --- mypy/checkexpr.py | 163 +++++++++++++++++++--------- mypy/main.py | 8 ++ mypy/options.py | 4 + mypy/plugins/default.py | 15 +++ mypy/types.py | 4 + test-data/unit/check-typeddict.test | 73 +++++++++++++ 6 files changed, 217 insertions(+), 50 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index cd0ff1100183..2e0be8e5fa5c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4,8 +4,9 @@ import itertools import time +from collections import defaultdict from contextlib import contextmanager -from typing import Callable, ClassVar, Iterator, List, Optional, Sequence, cast +from typing import Callable, ClassVar, Iterable, Iterator, List, Optional, Sequence, cast from typing_extensions import Final, TypeAlias as _TypeAlias, overload import mypy.checker @@ -685,74 +686,130 @@ def check_typeddict_call( context: Context, orig_callee: Type | None, ) -> Type: - if args and all([ak == ARG_NAMED for ak in arg_kinds]): + if args and all([ak in (ARG_NAMED, ARG_STAR2) for ak in arg_kinds]): # ex: Point(x=42, y=1337) - assert all(arg_name is not None for arg_name in arg_names) - item_names = cast(List[str], arg_names) - item_args = args - return self.check_typeddict_call_with_kwargs( - callee, dict(zip(item_names, item_args)), context, orig_callee - ) + kwargs = zip([StrExpr(n) if n is not None else None for n in arg_names], args) + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, always_present_keys = result + return self.check_typeddict_call_with_kwargs( + callee, validated_kwargs, context, orig_callee, always_present_keys + ) + return AnyType(TypeOfAny.from_error) if len(args) == 1 and arg_kinds[0] == ARG_POS: unique_arg = args[0] if isinstance(unique_arg, DictExpr): # ex: Point({'x': 42, 'y': 1337}) return self.check_typeddict_call_with_dict( - callee, unique_arg, context, orig_callee + callee, unique_arg.items, context, orig_callee ) if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr): # ex: Point(dict(x=42, y=1337)) return self.check_typeddict_call_with_dict( - callee, unique_arg.analyzed, context, orig_callee + callee, unique_arg.analyzed.items, context, orig_callee ) if not args: # ex: EmptyDict() - return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee) + return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee, set()) self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context) return AnyType(TypeOfAny.from_error) - def validate_typeddict_kwargs(self, kwargs: DictExpr) -> dict[str, Expression] | None: - item_args = [item[1] for item in kwargs.items] - - item_names = [] # List[str] - for item_name_expr, item_arg in kwargs.items: - literal_value = None + def validate_typeddict_kwargs( + self, kwargs: Iterable[tuple[Expression | None, Expression]], callee: TypedDictType + ) -> tuple[dict[str, list[Expression]], set[str]] | None: + result = defaultdict(list) + always_present_keys = set() + for item_name_expr, item_arg in kwargs: if item_name_expr: key_type = self.accept(item_name_expr) values = try_getting_str_literals(item_name_expr, key_type) + literal_value = None if values and len(values) == 1: literal_value = values[0] - if literal_value is None: - key_context = item_name_expr or item_arg - self.chk.fail( - message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, - key_context, - code=codes.LITERAL_REQ, - ) - return None + if literal_value is None: + key_context = item_name_expr or item_arg + self.chk.fail( + message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, + key_context, + code=codes.LITERAL_REQ, + ) + return None + else: + result[literal_value] = [item_arg] + always_present_keys.add(literal_value) else: - item_names.append(literal_value) - return dict(zip(item_names, item_args)) + with self.chk.local_type_map(), self.msg.filter_errors(): + inferred = get_proper_type(self.accept(item_arg, type_context=callee)) + if isinstance(inferred, TypedDictType): + possible_tds = [inferred] + elif isinstance(inferred, UnionType): + possible_tds = [] + for item in get_proper_types(inferred.relevant_items()): + if isinstance(item, TypedDictType): + possible_tds.append(item) + else: + self.chk.fail("Bad star", item_arg) + return None + else: + self.chk.fail("Bad star", item_arg) + return None + all_keys: set[str] = set() + for td in possible_tds: + all_keys |= td.items.keys() + for key in all_keys: + arg = TempNode( + UnionType.make_union( + [td.items[key] for td in possible_tds if key in td.items] + ) + ) + arg.set_line(item_arg) + if all(key in td.required_keys for td in possible_tds): + always_present_keys.add(key) + if result[key]: + # TODO: stricter checks with strict flag + first = result[key][0] + if isinstance(first, TempNode): + result[key] = [arg] + else: + result[key] = [first, arg] + else: + result[key] = [arg] + else: + result[key].append(arg) + return result, always_present_keys def match_typeddict_call_with_dict( - self, callee: TypedDictType, kwargs: DictExpr, context: Context + self, + callee: TypedDictType, + kwargs: list[tuple[Expression | None, Expression]], + context: Context, ) -> bool: - validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) - if validated_kwargs is not None: + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, _ = result return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys()) else: return False def check_typeddict_call_with_dict( - self, callee: TypedDictType, kwargs: DictExpr, context: Context, orig_callee: Type | None + self, + callee: TypedDictType, + kwargs: list[tuple[Expression | None, Expression]], + context: Context, + orig_callee: Type | None, ) -> Type: - validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) - if validated_kwargs is not None: + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, always_present_keys = result return self.check_typeddict_call_with_kwargs( - callee, kwargs=validated_kwargs, context=context, orig_callee=orig_callee + callee, + kwargs=validated_kwargs, + context=context, + orig_callee=orig_callee, + always_present_keys=always_present_keys, ) else: return AnyType(TypeOfAny.from_error) @@ -793,12 +850,15 @@ def typeddict_callable_from_context(self, callee: TypedDictType) -> CallableType def check_typeddict_call_with_kwargs( self, callee: TypedDictType, - kwargs: dict[str, Expression], + kwargs: dict[str, list[Expression]], context: Context, orig_callee: Type | None, + always_present_keys: set[str], ) -> Type: actual_keys = kwargs.keys() - if not (callee.required_keys <= actual_keys <= callee.items.keys()): + if not ( + callee.required_keys <= always_present_keys and actual_keys <= callee.items.keys() + ): expected_keys = [ key for key in callee.items.keys() @@ -829,7 +889,7 @@ def check_typeddict_call_with_kwargs( with self.msg.filter_errors(), self.chk.local_type_map(): orig_ret_type, _ = self.check_callable_call( infer_callee, - list(kwargs.values()), + [args[0] for args in kwargs.values()], [ArgKind.ARG_NAMED] * len(kwargs), context, list(kwargs.keys()), @@ -846,17 +906,18 @@ def check_typeddict_call_with_kwargs( for item_name, item_expected_type in ret_type.items.items(): if item_name in kwargs: - item_value = kwargs[item_name] - self.chk.check_simple_assignment( - lvalue_type=item_expected_type, - rvalue=item_value, - context=item_value, - msg=ErrorMessage( - message_registry.INCOMPATIBLE_TYPES.value, code=codes.TYPEDDICT_ITEM - ), - lvalue_name=f'TypedDict item "{item_name}"', - rvalue_name="expression", - ) + item_values = kwargs[item_name] + for item_value in item_values: + self.chk.check_simple_assignment( + lvalue_type=item_expected_type, + rvalue=item_value, + context=item_value, + msg=ErrorMessage( + message_registry.INCOMPATIBLE_TYPES.value, code=codes.TYPEDDICT_ITEM + ), + lvalue_name=f'TypedDict item "{item_name}"', + rvalue_name="expression", + ) return orig_ret_type @@ -4327,7 +4388,7 @@ def check_typeddict_literal_in_context( self, e: DictExpr, typeddict_context: TypedDictType ) -> Type: orig_ret_type = self.check_typeddict_call_with_dict( - callee=typeddict_context, kwargs=e, context=e, orig_callee=None + callee=typeddict_context, kwargs=e.items, context=e, orig_callee=None ) ret_type = get_proper_type(orig_ret_type) if isinstance(ret_type, TypedDictType): @@ -4427,7 +4488,9 @@ def find_typeddict_context( for item in context.items: item_contexts = self.find_typeddict_context(item, dict_expr) for item_context in item_contexts: - if self.match_typeddict_call_with_dict(item_context, dict_expr, dict_expr): + if self.match_typeddict_call_with_dict( + item_context, dict_expr.items, dict_expr + ): items.append(item_context) return items # No TypedDict type in context. diff --git a/mypy/main.py b/mypy/main.py index 81a0a045745b..08ceecfb0572 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -833,6 +833,14 @@ def add_invertible_flag( group=strictness_group, ) + add_invertible_flag( + "--strict-typeddict-update", + default=False, + strict_flag=True, + help="Disallow partial overlap in TypedDict update (including ** in constructor)", + group=strictness_group, + ) + strict_help = "Strict mode; enables the following flags: {}".format( ", ".join(strict_flag_names) ) diff --git a/mypy/options.py b/mypy/options.py index 45591597ba69..a4601af7cea2 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -51,6 +51,7 @@ class BuildType: "strict_concatenate", "strict_equality", "strict_optional", + "strict_typeddict_update", "warn_no_return", "warn_return_any", "warn_unreachable", @@ -203,6 +204,9 @@ def __init__(self) -> None: # Make arguments prepended via Concatenate be truly positional-only. self.strict_concatenate = False + # Disallow partial overlap in TypedDict update (including ** in constructor). + self.strict_typeddict_update = False + # Report an error for any branches inferred to be unreachable as a result of # type analysis. self.warn_unreachable = False diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 500eef76a9d9..8d2ed9fd960c 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -402,6 +402,21 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: assert isinstance(arg_type, TypedDictType) arg_type = arg_type.as_anonymous() arg_type = arg_type.copy_modified(required_keys=set()) + if ctx.args and ctx.args[0]: + with ctx.api.msg.filter_errors(): + inferred = get_proper_type( + ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type) + ) + # TODO: unions + if isinstance(inferred, TypedDictType): + arg_type = arg_type.copy_modified( + # TODO: extra keys + required_keys=arg_type.required_keys + | inferred.required_keys + ) + if not ctx.api.options.strict_typeddict_update: + # TODO: extra names + arg_type = arg_type.copy_modified(item_names=list(inferred.items)) return signature.copy_modified(arg_types=[arg_type]) return signature diff --git a/mypy/types.py b/mypy/types.py index 5fbdd385826c..127a7d7a8355 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2433,6 +2433,7 @@ def copy_modified( *, fallback: Instance | None = None, item_types: list[Type] | None = None, + item_names: list[str] | None = None, required_keys: set[str] | None = None, ) -> TypedDictType: if fallback is None: @@ -2443,6 +2444,9 @@ def copy_modified( items = dict(zip(self.items, item_types)) if required_keys is None: required_keys = self.required_keys + if item_names is not None: + items = {k: v for (k, v) in items.items() if k in item_names} + required_keys &= set(item_names) return TypedDictType(items, required_keys, fallback, self.line, self.column) def create_anonymous_fallback(self) -> Instance: diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index fc487d2d553d..88a6d8b3d8bb 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2885,3 +2885,76 @@ d: A d[''] # E: TypedDict "A" has no key "" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdate] +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) + +a = A({"foo": 1, "bar": 2}) +b = B({"foo": 2}) +a.update({"foo": 2}) +a.update(b) +a.update(a) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictStrictUpdate] +# flags: --strict-typeddict-update +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) + +a = A({"foo": 1, "bar": 2}) +b = B({"foo": 2}) +a.update({"foo": 2}) # OK +a.update(b) # E: Argument 1 to "update" of "TypedDict" has incompatible type "B"; expected "TypedDict({'foo': int, 'bar'?: int})" +a.update(a) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackSame] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +foo1: Foo = {'a': 1, 'b': 1} +foo2: Foo = {**foo1, 'b': 2} +foo3 = Foo(**foo1, b=2) +foo4 = Foo({**foo1, 'b': 2}) +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackCompatible] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {'a': 1} +bar: Bar = {**foo, 'b': 2} +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackIncompatible] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: str + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {'a': 1, 'b': 'a'} +bar1: Bar = {**foo, 'b': 2} # Incompatible item is overriden +bar2: Bar = {**foo, 'a': 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] From b8b4fa7fab22273670710824feba71ab17b175da Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 12 Jun 2023 01:06:52 +0100 Subject: [PATCH 02/11] Cleanups; update union handling; more tests --- mypy/checkexpr.py | 147 ++++++++++++------ mypy/messages.py | 9 ++ mypy/plugins/default.py | 26 +++- test-data/unit/check-typeddict.test | 222 +++++++++++++++++++++++++++- 4 files changed, 340 insertions(+), 64 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2e0be8e5fa5c..75f80f651bed 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -687,7 +687,9 @@ def check_typeddict_call( orig_callee: Type | None, ) -> Type: if args and all([ak in (ARG_NAMED, ARG_STAR2) for ak in arg_kinds]): - # ex: Point(x=42, y=1337) + # ex: Point(x=42, y=1337, **extras) + # This is a bit ugly, but this is a price for supporting all possible syntax + # variants for TypedDict constructors. kwargs = zip([StrExpr(n) if n is not None else None for n in arg_names], args) result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) if result is not None: @@ -700,12 +702,12 @@ def check_typeddict_call( if len(args) == 1 and arg_kinds[0] == ARG_POS: unique_arg = args[0] if isinstance(unique_arg, DictExpr): - # ex: Point({'x': 42, 'y': 1337}) + # ex: Point({'x': 42, 'y': 1337, **extras}) return self.check_typeddict_call_with_dict( callee, unique_arg.items, context, orig_callee ) if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr): - # ex: Point(dict(x=42, y=1337)) + # ex: Point(dict(x=42, y=1337, **extras)) return self.check_typeddict_call_with_dict( callee, unique_arg.analyzed.items, context, orig_callee ) @@ -720,8 +722,11 @@ def check_typeddict_call( def validate_typeddict_kwargs( self, kwargs: Iterable[tuple[Expression | None, Expression]], callee: TypedDictType ) -> tuple[dict[str, list[Expression]], set[str]] | None: + # All (actual or mapped from ** unpacks) expressions that can match given key. result = defaultdict(list) + # Keys that are guaranteed to be present no matter what (e.g. for all items of a union) always_present_keys = set() + for item_name_expr, item_arg in kwargs: if item_name_expr: key_type = self.accept(item_name_expr) @@ -738,49 +743,76 @@ def validate_typeddict_kwargs( ) return None else: + # A directly present key unconditionally shadows all previously found + # values from ** items. + # TODO: for duplicate keys, type-check all values. result[literal_value] = [item_arg] always_present_keys.add(literal_value) else: - with self.chk.local_type_map(), self.msg.filter_errors(): - inferred = get_proper_type(self.accept(item_arg, type_context=callee)) - if isinstance(inferred, TypedDictType): - possible_tds = [inferred] - elif isinstance(inferred, UnionType): - possible_tds = [] - for item in get_proper_types(inferred.relevant_items()): - if isinstance(item, TypedDictType): - possible_tds.append(item) - else: - self.chk.fail("Bad star", item_arg) - return None - else: - self.chk.fail("Bad star", item_arg) - return None - all_keys: set[str] = set() - for td in possible_tds: - all_keys |= td.items.keys() - for key in all_keys: - arg = TempNode( - UnionType.make_union( - [td.items[key] for td in possible_tds if key in td.items] - ) - ) - arg.set_line(item_arg) - if all(key in td.required_keys for td in possible_tds): - always_present_keys.add(key) - if result[key]: - # TODO: stricter checks with strict flag - first = result[key][0] - if isinstance(first, TempNode): - result[key] = [arg] - else: - result[key] = [first, arg] - else: - result[key] = [arg] - else: - result[key].append(arg) + if not self.validate_star_typeddict_item( + item_arg, callee, result, always_present_keys + ): + return None return result, always_present_keys + def validate_star_typeddict_item( + self, + item_arg: Expression, + callee: TypedDictType, + result: dict[str, list[Expression]], + always_present_keys: set[str], + ) -> bool: + """Update keys/expressions from a ** expression in TypedDict constructor. + + Note `result` and `always_present_keys` are updated in place. Return true if the + expression `item_arg` may valid in `callee` TypedDict context. + """ + with self.chk.local_type_map(), self.msg.filter_errors(): + inferred = get_proper_type(self.accept(item_arg, type_context=callee)) + if isinstance(inferred, TypedDictType): + possible_tds = [inferred] + elif isinstance(inferred, UnionType): + possible_tds = [] + for item in get_proper_types(inferred.relevant_items()): + if isinstance(item, TypedDictType): + possible_tds.append(item) + else: + self.msg.unsupported_target_for_star_typeddict(item, item_arg) + return False + else: + self.msg.unsupported_target_for_star_typeddict(inferred, item_arg) + return False + all_keys: set[str] = set() + for td in possible_tds: + all_keys |= td.items.keys() + for key in all_keys: + arg = TempNode( + UnionType.make_union([td.items[key] for td in possible_tds if key in td.items]) + ) + arg.set_line(item_arg) + if all(key in td.required_keys for td in possible_tds): + always_present_keys.add(key) + # Always present keys override previously found values. This is done + # to support use cases like `Config({**defaults, **overrides})`, where + # some `overrides` types are narrower that types in `defaults`, and + # former are too wide for `Config`. + if result[key]: + first = result[key][0] + if not isinstance(first, TempNode): + # We must always preserve any non-synthetic values, so that + # we will accept them even if they are shadowed. + result[key] = [first, arg] + else: + result[key] = [arg] + else: + result[key] = [arg] + else: + # If this key is not required at least in some item of a union + # it may not shadow previous item, so we need to type check both. + result[key].append(arg) + # TODO: detect possibly unsafe ** overrides in --strict-typeddict-update mode. + return True + def match_typeddict_call_with_dict( self, callee: TypedDictType, @@ -859,14 +891,28 @@ def check_typeddict_call_with_kwargs( if not ( callee.required_keys <= always_present_keys and actual_keys <= callee.items.keys() ): - expected_keys = [ - key - for key in callee.items.keys() - if key in callee.required_keys or key in actual_keys - ] - self.msg.unexpected_typeddict_keys( - callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context - ) + if not (actual_keys <= callee.items.keys()): + self.msg.unexpected_typeddict_keys( + callee, + expected_keys=[ + key + for key in callee.items.keys() + if key in callee.required_keys or key in actual_keys + ], + actual_keys=list(actual_keys), + context=context, + ) + if not (callee.required_keys <= always_present_keys): + self.msg.unexpected_typeddict_keys( + callee, + expected_keys=[ + key for key in callee.items.keys() if key in callee.required_keys + ], + actual_keys=[ + key for key in always_present_keys if key in callee.required_keys + ], + context=context, + ) if callee.required_keys > actual_keys: # found_set is a sub-set of the required_keys # This means we're missing some keys and as such, we can't @@ -889,6 +935,9 @@ def check_typeddict_call_with_kwargs( with self.msg.filter_errors(), self.chk.local_type_map(): orig_ret_type, _ = self.check_callable_call( infer_callee, + # We use first expression for each key to infer type variables of a generic + # TypedDict. This is a bit arbitrary, but in most cases will work better than + # trying to infer a union or a join. [args[0] for args in kwargs.values()], [ArgKind.ARG_NAMED] * len(kwargs), context, diff --git a/mypy/messages.py b/mypy/messages.py index 9d703a1a974a..85c8a49dc550 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1754,6 +1754,15 @@ def need_annotation_for_var( def explicit_any(self, ctx: Context) -> None: self.fail('Explicit "Any" is not allowed', ctx) + def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None: + self.fail( + "Unsupported type {} for ** expansion in TypedDict".format( + format_type(typ, self.options) + ), + ctx, + code=codes.TYPEDDICT_ITEM, + ) + def unexpected_typeddict_keys( self, typ: TypedDictType, diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 8d2ed9fd960c..b1317440163b 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -31,7 +31,9 @@ TypedDictType, TypeOfAny, TypeVarType, + UnionType, get_proper_type, + get_proper_types, ) @@ -407,16 +409,26 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: inferred = get_proper_type( ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type) ) - # TODO: unions + possible_tds = [] if isinstance(inferred, TypedDictType): - arg_type = arg_type.copy_modified( - # TODO: extra keys - required_keys=arg_type.required_keys - | inferred.required_keys + possible_tds = [inferred] + elif isinstance(inferred, UnionType): + possible_tds = [ + t + for t in get_proper_types(inferred.relevant_items()) + if isinstance(t, TypedDictType) + ] + items = [] + for td in possible_tds: + item = arg_type.copy_modified( + required_keys=(arg_type.required_keys | td.required_keys) + & arg_type.items.keys() ) if not ctx.api.options.strict_typeddict_update: - # TODO: extra names - arg_type = arg_type.copy_modified(item_names=list(inferred.items)) + item = item.copy_modified(item_names=list(td.items)) + items.append(item) + if items: + arg_type = make_simplified_union(items) return signature.copy_modified(arg_types=[arg_type]) return signature diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 88a6d8b3d8bb..3cd7083ff6e1 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2915,6 +2915,53 @@ a.update(a) # OK [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testTypedDictFlexibleUpdateUnion] +from typing import Union +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) +C = TypedDict("C", {"bar": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdateUnionExtra] +from typing import Union +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int, "extra": int}) +C = TypedDict("C", {"bar": int, "extra": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdateUnionStrict] +# flags: --strict-typeddict-update +from typing import Union, NotRequired +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +A1 = TypedDict("A1", {"foo": int, "bar": NotRequired[int]}) +A2 = TypedDict("A2", {"foo": NotRequired[int], "bar": int}) +B = TypedDict("B", {"foo": int}) +C = TypedDict("C", {"bar": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) # E: Argument 1 to "update" of "TypedDict" has incompatible type "Union[B, C]"; expected "Union[TypedDict({'foo': int, 'bar'?: int}), TypedDict({'foo'?: int, 'bar': int})]" +u2: Union[A1, A2] +a.update(u2) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testTypedDictUnpackSame] from typing import TypedDict @@ -2922,10 +2969,11 @@ class Foo(TypedDict): a: int b: int -foo1: Foo = {'a': 1, 'b': 1} -foo2: Foo = {**foo1, 'b': 2} +foo1: Foo = {"a": 1, "b": 1} +foo2: Foo = {**foo1, "b": 2} foo3 = Foo(**foo1, b=2) -foo4 = Foo({**foo1, 'b': 2}) +foo4 = Foo({**foo1, "b": 2}) +[builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackCompatible] @@ -2938,8 +2986,9 @@ class Bar(TypedDict): a: int b: int -foo: Foo = {'a': 1} -bar: Bar = {**foo, 'b': 2} +foo: Foo = {"a": 1} +bar: Bar = {**foo, "b": 2} +[builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackIncompatible] @@ -2953,8 +3002,165 @@ class Bar(TypedDict): a: int b: int -foo: Foo = {'a': 1, 'b': 'a'} -bar1: Bar = {**foo, 'b': 2} # Incompatible item is overriden -bar2: Bar = {**foo, 'a': 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +foo: Foo = {"a": 1, "b": "a"} +bar1: Bar = {**foo, "b": 2} # Incompatible item is overriden +bar2: Bar = {**foo, "a": 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNotRequiredKeyIncompatible] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[str] + +class Bar(TypedDict): + a: NotRequired[int] + +foo: Foo = {} +bar: Bar = {**foo} # E: Incompatible types (expression has type "str", TypedDict item "a" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackMissingOrExtraKey] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo1: Foo = {"a": 1} +bar1: Bar = {"a": 1, "b": 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} # E: Missing key "b" for TypedDict "Bar" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNotRequiredKeyExtra] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo1: Foo = {"a": 1} +bar1: Bar = {"a": 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackRequiredKeyMissing] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[int] + +class Bar(TypedDict): + a: int + +foo: Foo = {"a": 1} +bar: Bar = {**foo} # E: Missing key "a" for TypedDict "Bar" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackMultiple] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +class Baz(TypedDict): + a: int + b: int + c: int + +foo: Foo = {"a": 1} +bar: Bar = {"b": 1} +baz: Baz = {**foo, **bar, "c": 1} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNested] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {"a": 1, "b": 1} +bar: Bar = {"c": foo, "d": 1} +bar2: Bar = {**bar, "c": {**bar["c"], "b": 2}, "d": 2} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNestedError] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {"a": 1, "b": 1} +bar: Bar = {"c": foo, "d": 1} +bar2: Bar = {**bar, "c": {**bar["c"], "b": "wrong"}, "d": 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackUntypedDict] +from typing import TypedDict + +class Bar(TypedDict): + pass + +foo: dict = {} +bar: Bar = {**foo} # E: Unsupported type "Dict[Any, Any]" for ** expansion in TypedDict +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackIntoUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +foo: Foo = {'a': 1} +foo_or_bar: Union[Foo, Bar] = {**foo} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackFromUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + b: int + +foo_or_bar: Union[Foo, Bar] = {'b': 1} +foo: Bar = {**foo_or_bar} # E: Extra key "a" for TypedDict "Bar" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] From 7925535205d845d826f308bbbd9747ed45978355 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 12 Jun 2023 23:30:35 +0100 Subject: [PATCH 03/11] Add inference test case --- test-data/unit/check-typeddict.test | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 3cd7083ff6e1..56f0a8e9c9dd 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3164,3 +3164,21 @@ foo_or_bar: Union[Foo, Bar] = {'b': 1} foo: Bar = {**foo_or_bar} # E: Extra key "a" for TypedDict "Bar" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackInference] +from typing import TypedDict, Generic, TypeVar + +class Foo(TypedDict): + a: int + b: str + +T = TypeVar("T") +class TD(TypedDict, Generic[T]): + a: T + b: str + +foo: Foo +bar = TD(**foo) +reveal_type(bar) # N: Revealed type is "TypedDict('__main__.TD', {'a': builtins.int, 'b': builtins.str})" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] From 1aa42188aafc2df317ad54ad7b8e5ec4921f3908 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 12 Jun 2023 23:41:14 +0100 Subject: [PATCH 04/11] Add another test case from issue --- test-data/unit/check-typeddict.test | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 56f0a8e9c9dd..0c6685fea2ad 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3125,6 +3125,18 @@ bar2: Bar = {**bar, "c": {**bar["c"], "b": "wrong"}, "d": 2} # E: Incompatible [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testTypedDictUnpackOverrideRequired] +from mypy_extensions import TypedDict + +Details = TypedDict('Details', {'first_name': str, 'last_name': str}) +DetailsSubset = TypedDict('DetailsSubset', {'first_name': str, 'last_name': str}, total=False) +defaults: Details = {'first_name': 'John', 'last_name': 'Luther'} + +def generate(data: DetailsSubset) -> Details: + return {**defaults, **data} # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testTypedDictUnpackUntypedDict] from typing import TypedDict From 57f57c4ce105731c91cf57dd3bf7d95a6116fcfb Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 13 Jun 2023 00:50:13 +0100 Subject: [PATCH 05/11] Handle strictnes also for star items --- mypy/checkexpr.py | 15 ++++++++++++++- mypy/messages.py | 9 +++++++++ test-data/unit/check-typeddict.test | 16 ++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 75f80f651bed..4307733cf1a0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -726,6 +726,8 @@ def validate_typeddict_kwargs( result = defaultdict(list) # Keys that are guaranteed to be present no matter what (e.g. for all items of a union) always_present_keys = set() + # Indicates latest encountered ** unpack among items. + last_star_found = None for item_name_expr, item_arg in kwargs: if item_name_expr: @@ -749,10 +751,22 @@ def validate_typeddict_kwargs( result[literal_value] = [item_arg] always_present_keys.add(literal_value) else: + last_star_found = item_arg if not self.validate_star_typeddict_item( item_arg, callee, result, always_present_keys ): return None + if self.chk.options.strict_typeddict_update and last_star_found is not None: + absent_keys = [] + for key in callee.items: + if key not in callee.required_keys and key not in result: + absent_keys.append(key) + if absent_keys: + # Having an optional key not explicitly declared by a ** unpacked + # TypedDict is unsafe, it may be an (incompatible) subtype at runtime. + # TODO: catch the cases where a declared key is overridden by a subsequent + # ** item without it (and not again overriden with complete ** item). + self.msg.non_required_keys_absent_with_star(absent_keys, last_star_found) return result, always_present_keys def validate_star_typeddict_item( @@ -810,7 +824,6 @@ def validate_star_typeddict_item( # If this key is not required at least in some item of a union # it may not shadow previous item, so we need to type check both. result[key].append(arg) - # TODO: detect possibly unsafe ** overrides in --strict-typeddict-update mode. return True def match_typeddict_call_with_dict( diff --git a/mypy/messages.py b/mypy/messages.py index 85c8a49dc550..2974617f025c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1763,6 +1763,15 @@ def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None code=codes.TYPEDDICT_ITEM, ) + def non_required_keys_absent_with_star(self, keys: list[str], ctx: Context) -> None: + self.fail( + "Non-required {} not explicitly found in any ** item".format( + format_key_list(keys, short=True) + ), + ctx, + code=codes.TYPEDDICT_ITEM, + ) + def unexpected_typeddict_keys( self, typ: TypedDictType, diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 0c6685fea2ad..67d9477457f3 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3194,3 +3194,19 @@ bar = TD(**foo) reveal_type(bar) # N: Revealed type is "TypedDict('__main__.TD', {'a': builtins.int, 'b': builtins.str})" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackStrictMode] +# flags: --strict-typeddict-update +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo: Foo +bar: Bar = {**foo} # E: Non-required key "b" not explicitly found in any ** item +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] From 4d2f9619ec2821eab1568fb7f48e855c7153fad9 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 13 Jun 2023 10:26:02 +0100 Subject: [PATCH 06/11] Allow Any in star unpacks; add one more union test --- mypy/checkexpr.py | 14 ++++++++++--- test-data/unit/check-typeddict.test | 32 +++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4307733cf1a0..9605d199afb5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -783,17 +783,17 @@ def validate_star_typeddict_item( """ with self.chk.local_type_map(), self.msg.filter_errors(): inferred = get_proper_type(self.accept(item_arg, type_context=callee)) + possible_tds = [] if isinstance(inferred, TypedDictType): possible_tds = [inferred] elif isinstance(inferred, UnionType): - possible_tds = [] for item in get_proper_types(inferred.relevant_items()): if isinstance(item, TypedDictType): possible_tds.append(item) - else: + elif not self.valid_unpack_fallback_item(item): self.msg.unsupported_target_for_star_typeddict(item, item_arg) return False - else: + elif not self.valid_unpack_fallback_item(inferred): self.msg.unsupported_target_for_star_typeddict(inferred, item_arg) return False all_keys: set[str] = set() @@ -826,6 +826,14 @@ def validate_star_typeddict_item( result[key].append(arg) return True + def valid_unpack_fallback_item(self, typ: ProperType) -> bool: + if isinstance(typ, AnyType): + return True + if not isinstance(typ, Instance) or not typ.type.has_base("typing.Mapping"): + return False + mapped = map_instance_to_supertype(typ, self.chk.lookup_typeinfo("typing.Mapping")) + return all(isinstance(a, AnyType) for a in get_proper_types(mapped.args)) + def match_typeddict_call_with_dict( self, callee: TypedDictType, diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 67d9477457f3..e92b5e4d5079 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3177,6 +3177,22 @@ foo: Bar = {**foo_or_bar} # E: Extra key "a" for TypedDict "Bar" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] +[case testTypedDictUnpackUnionRequiredMissing] +from typing import TypedDict, NotRequired, Union + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo_or_bar: Union[Foo, Bar] = {"a": 1} +foo: Foo = {**foo_or_bar} # E: Missing key "b" for TypedDict "Foo" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + [case testTypedDictUnpackInference] from typing import TypedDict, Generic, TypeVar @@ -3210,3 +3226,19 @@ foo: Foo bar: Bar = {**foo} # E: Non-required key "b" not explicitly found in any ** item [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackAny] +from typing import Any, TypedDict, NotRequired, Dict, Union + +class Foo(TypedDict): + a: int + b: NotRequired[int] + +x: Any +y: Dict[Any, Any] +z: Union[Any, Dict[Any, Any]] +t1: Foo = {**x} # E: Missing key "a" for TypedDict "Foo" +t2: Foo = {**y} # E: Missing key "a" for TypedDict "Foo" +t3: Foo = {**z} # E: Missing key "a" for TypedDict "Foo" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] From 7b0ecb57c6ab653398fac4a0cfc9b6948edda8a4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 13 Jun 2023 10:29:16 +0100 Subject: [PATCH 07/11] Update test --- test-data/unit/check-typeddict.test | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index e92b5e4d5079..edd60324d1b3 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3138,13 +3138,13 @@ def generate(data: DetailsSubset) -> Details: [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackUntypedDict] -from typing import TypedDict +from typing import Any, Dict, TypedDict class Bar(TypedDict): pass -foo: dict = {} -bar: Bar = {**foo} # E: Unsupported type "Dict[Any, Any]" for ** expansion in TypedDict +foo: Dict[str, Any] = {} +bar: Bar = {**foo} # E: Unsupported type "Dict[str, Any]" for ** expansion in TypedDict [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] From d6ed5cfcb351390cd410fe91e6859e03b48255f3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 13 Jun 2023 14:01:47 +0100 Subject: [PATCH 08/11] Support plain dict syntax as well --- mypy/semanal.py | 4 ++-- test-data/unit/check-typeddict.test | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 073bde661617..0fdad2222608 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -5084,14 +5084,14 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None: For other variants of dict(...), return None. """ - if not all(kind == ARG_NAMED for kind in call.arg_kinds): + if not all(kind in (ARG_NAMED, ARG_STAR2) for kind in call.arg_kinds): # Must still accept those args. for a in call.args: a.accept(self) return None expr = DictExpr( [ - (StrExpr(cast(str, key)), value) # since they are all ARG_NAMED + (StrExpr(key) if key is not None else None, value) for key, value in zip(call.arg_names, call.args) ] ) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index edd60324d1b3..45deae2cec1e 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2973,6 +2973,7 @@ foo1: Foo = {"a": 1, "b": 1} foo2: Foo = {**foo1, "b": 2} foo3 = Foo(**foo1, b=2) foo4 = Foo({**foo1, "b": 2}) +foo5 = Foo(dict(**foo1, b=2)) [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] From 98cdc481966fea212fbb398d1d88b45020f9d7cc Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 22 Jun 2023 23:19:39 +0100 Subject: [PATCH 09/11] Address CR --- test-data/unit/check-typeddict.test | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 45deae2cec1e..15a4e43f47b2 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2963,6 +2963,7 @@ a.update(u2) # OK [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackSame] +# flags: --strict-typeddict-update from typing import TypedDict class Foo(TypedDict): @@ -2978,6 +2979,7 @@ foo5 = Foo(dict(**foo1, b=2)) [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackCompatible] +# flags: --strict-typeddict-update from typing import TypedDict class Foo(TypedDict): @@ -3073,6 +3075,7 @@ bar: Bar = {**foo} # E: Missing key "a" for TypedDict "Bar" [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackMultiple] +# flags: --strict-typeddict-update from typing import TypedDict class Foo(TypedDict): From a98f555de2b1fe6babeb4c039c0485b9bbe6f24b Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 23 Jun 2023 23:48:35 +0100 Subject: [PATCH 10/11] Merge flags and deprecate old one --- mypy/checkexpr.py | 2 +- mypy/main.py | 21 ++++++++++--------- mypy/options.py | 8 +++---- mypy/plugins/default.py | 2 +- mypy/subtypes.py | 10 +++++++-- .../unit/check-parameter-specification.test | 2 +- test-data/unit/check-typeddict.test | 12 +++++------ 7 files changed, 32 insertions(+), 25 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3b423697ad14..986e58c21762 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -766,7 +766,7 @@ def validate_typeddict_kwargs( item_arg, callee, result, always_present_keys ): return None - if self.chk.options.strict_typeddict_update and last_star_found is not None: + if self.chk.options.extra_checks and last_star_found is not None: absent_keys = [] for key in callee.items: if key not in callee.required_keys and key not in result: diff --git a/mypy/main.py b/mypy/main.py index 6d48220a200c..22ff3e32a718 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -826,18 +826,12 @@ def add_invertible_flag( ) add_invertible_flag( - "--strict-concatenate", + "--extra-checks", default=False, strict_flag=True, - help="Make arguments prepended via Concatenate be truly positional-only", - group=strictness_group, - ) - - add_invertible_flag( - "--strict-typeddict-update", - default=False, - strict_flag=True, - help="Disallow partial overlap in TypedDict update (including ** in constructor)", + help="Enable additional checks that are technically correct but may be impractical " + "in real code. For example, this prohibits partial overlap in TypedDict updates, " + "and makes arguments prepended via Concatenate positional-only", group=strictness_group, ) @@ -1163,6 +1157,8 @@ def add_invertible_flag( parser.add_argument( "--disable-memoryview-promotion", action="store_true", help=argparse.SUPPRESS ) + # This flag is deprecated, it has been moved to --extra-checks + parser.add_argument("--strict-concatenate", action="store_true", help=argparse.SUPPRESS) # options specifying code to check code_group = parser.add_argument_group( @@ -1234,8 +1230,11 @@ def add_invertible_flag( parser.error(f"Cannot find config file '{config_file}'") options = Options() + strict_option_set = False def set_strict_flags() -> None: + nonlocal strict_option_set + strict_option_set = True for dest, value in strict_flag_assignments: setattr(options, dest, value) @@ -1387,6 +1386,8 @@ def set_strict_flags() -> None: "Warning: --enable-recursive-aliases is deprecated;" " recursive types are enabled by default" ) + if options.strict_concatenate and not strict_option_set: + print("Warning: --strict-concatenate is deprecated; use --extra-checks instead") # Set target. if special_opts.modules + special_opts.packages: diff --git a/mypy/options.py b/mypy/options.py index 0d8e46af8a93..e1d731c1124c 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -40,6 +40,7 @@ class BuildType: "disallow_untyped_defs", "enable_error_code", "enabled_error_codes", + "extra_checks", "follow_imports_for_stubs", "follow_imports", "ignore_errors", @@ -51,7 +52,6 @@ class BuildType: "strict_concatenate", "strict_equality", "strict_optional", - "strict_typeddict_update", "warn_no_return", "warn_return_any", "warn_unreachable", @@ -201,11 +201,11 @@ def __init__(self) -> None: # This makes 1 == '1', 1 in ['1'], and 1 is '1' errors. self.strict_equality = False - # Make arguments prepended via Concatenate be truly positional-only. + # Deprecated, use extra_checks instead. self.strict_concatenate = False - # Disallow partial overlap in TypedDict update (including ** in constructor). - self.strict_typeddict_update = False + # Enable additional checks that are technically correct but impractical. + self.extra_checks = False # Report an error for any branches inferred to be unreachable as a result of # type analysis. diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 5af4d27cbff6..f5dea0621177 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -426,7 +426,7 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: required_keys=(arg_type.required_keys | td.required_keys) & arg_type.items.keys() ) - if not ctx.api.options.strict_typeddict_update: + if not ctx.api.options.extra_checks: item = item.copy_modified(item_names=list(td.items)) items.append(item) if items: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a3b28a3e24de..c9de56edfa36 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -694,7 +694,9 @@ def visit_callable_type(self, left: CallableType) -> bool: right, is_compat=self._is_subtype, ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, - strict_concatenate=self.options.strict_concatenate if self.options else True, + strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate) + if self.options + else True, ) elif isinstance(right, Overloaded): return all(self._is_subtype(left, item) for item in right.items) @@ -858,7 +860,11 @@ def visit_overloaded(self, left: Overloaded) -> bool: else: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. - strict_concat = self.options.strict_concatenate if self.options else True + strict_concat = ( + (self.options.extra_checks or self.options.strict_concatenate) + if self.options + else True + ) if left_index not in matched_overloads and ( is_callable_compatible( left_item, diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index cafcaca0a14c..bebbbf4b1676 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -570,7 +570,7 @@ reveal_type(f(n)) # N: Revealed type is "def (builtins.int, builtins.bytes) -> [builtins fixtures/paramspec.pyi] [case testParamSpecConcatenateNamedArgs] -# flags: --python-version 3.8 --strict-concatenate +# flags: --python-version 3.8 --extra-checks # this is one noticeable deviation from PEP but I believe it is for the better from typing_extensions import ParamSpec, Concatenate from typing import Callable, TypeVar diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 15a4e43f47b2..4d2d64848515 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2901,7 +2901,7 @@ a.update(a) [typing fixtures/typing-typeddict.pyi] [case testTypedDictStrictUpdate] -# flags: --strict-typeddict-update +# flags: --extra-checks from mypy_extensions import TypedDict A = TypedDict("A", {"foo": int, "bar": int}) @@ -2944,7 +2944,7 @@ a.update(u) [typing fixtures/typing-typeddict.pyi] [case testTypedDictFlexibleUpdateUnionStrict] -# flags: --strict-typeddict-update +# flags: --extra-checks from typing import Union, NotRequired from mypy_extensions import TypedDict @@ -2963,7 +2963,7 @@ a.update(u2) # OK [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackSame] -# flags: --strict-typeddict-update +# flags: --extra-checks from typing import TypedDict class Foo(TypedDict): @@ -2979,7 +2979,7 @@ foo5 = Foo(dict(**foo1, b=2)) [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackCompatible] -# flags: --strict-typeddict-update +# flags: --extra-checks from typing import TypedDict class Foo(TypedDict): @@ -3075,7 +3075,7 @@ bar: Bar = {**foo} # E: Missing key "a" for TypedDict "Bar" [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackMultiple] -# flags: --strict-typeddict-update +# flags: --extra-checks from typing import TypedDict class Foo(TypedDict): @@ -3216,7 +3216,7 @@ reveal_type(bar) # N: Revealed type is "TypedDict('__main__.TD', {'a': builtins [typing fixtures/typing-typeddict.pyi] [case testTypedDictUnpackStrictMode] -# flags: --strict-typeddict-update +# flags: --extra-checks from typing import TypedDict, NotRequired class Foo(TypedDict): From 12fcb4fbd1719a16a636964a492bc37161d8c8b8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 25 Jun 2023 23:47:04 +0100 Subject: [PATCH 11/11] Add docs for new flag --- docs/source/command_line.rst | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index 2809294092ab..d9de5cd8f9bd 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -612,6 +612,34 @@ of the above sections. assert text is not None # OK, check against None is allowed as a special case. +.. option:: --extra-checks + + This flag enables additional checks that are technically correct but may be + impractical in real code. In particular, it prohibits partial overlap in + ``TypedDict`` updates, and makes arguments prepended via ``Concatenate`` + positional-only. For example: + + .. code-block:: python + + from typing import TypedDict + + class Foo(TypedDict): + a: int + + class Bar(TypedDict): + a: int + b: int + + def test(foo: Foo, bar: Bar) -> None: + # This is technically unsafe since foo can have a subtype of Foo at + # runtime, where type of key "b" is incompatible with int, see below + bar.update(foo) + + class Bad(Foo): + b: str + bad: Bad = {"a": 0, "b": "no"} + test(bad, bar) + .. option:: --strict This flag mode enables all optional error checking flags. You can see the