Skip to content

Commit 8290bb8

Browse files
authored
Support flexible TypedDict creation/update (#15425)
Fixes #9408 Fixes #4122 Fixes #6462 Supersedes #13353 This PR enables two similar technically unsafe behaviors for TypedDicts, as @JukkaL explained in #6462 (comment) allowing an "incomplete" TypedDict as an argument to `.update()` is technically unsafe (and a similar argument applies to `**` syntax in TypedDict literals). These are however very common patterns (judging from number of duplicates to above issues), so I think we should support them. Here is what I propose: * Always support cases that are safe (like passing the type itself to `update`) * Allow popular but technically unsafe cases _by default_ * Have a new flag (as part of `--strict`) to fall back to current behavior Note that unfortunately we can't use just a custom new error code, since we need to conditionally tweak some types in a plugin. Btw there are couple TODOs I add here: * First is for unsafe behavior for repeated TypedDict keys. This is not new, I just noticed it when working on this * Second is for tricky corner case involving multiple `**` items where we may have false-negatives in strict mode. Note that I don't test all the possible combinations here (since the phase space is huge), but I think I am testing all main ingredients (and I will be glad to add more if needed): * All syntax variants for TypedDicts creation are handled * Various shadowing/overrides scenarios * Required vs non-required keys handling * Union types (both as item and target types) * Inference for generic TypedDicts * New strictness flag More than half of the tests I took from the original PR #13353
1 parent 7ce3568 commit 8290bb8

11 files changed

+659
-69
lines changed

docs/source/command_line.rst

+28
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,34 @@ of the above sections.
612612
613613
assert text is not None # OK, check against None is allowed as a special case.
614614
615+
.. option:: --extra-checks
616+
617+
This flag enables additional checks that are technically correct but may be
618+
impractical in real code. In particular, it prohibits partial overlap in
619+
``TypedDict`` updates, and makes arguments prepended via ``Concatenate``
620+
positional-only. For example:
621+
622+
.. code-block:: python
623+
624+
from typing import TypedDict
625+
626+
class Foo(TypedDict):
627+
a: int
628+
629+
class Bar(TypedDict):
630+
a: int
631+
b: int
632+
633+
def test(foo: Foo, bar: Bar) -> None:
634+
# This is technically unsafe since foo can have a subtype of Foo at
635+
# runtime, where type of key "b" is incompatible with int, see below
636+
bar.update(foo)
637+
638+
class Bad(Foo):
639+
b: str
640+
bad: Bad = {"a": 0, "b": "no"}
641+
test(bad, bar)
642+
615643
.. option:: --strict
616644

617645
This flag mode enables all optional error checking flags. You can see the

mypy/checkexpr.py

+194-61
Large diffs are not rendered by default.

mypy/main.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -826,10 +826,12 @@ def add_invertible_flag(
826826
)
827827

828828
add_invertible_flag(
829-
"--strict-concatenate",
829+
"--extra-checks",
830830
default=False,
831831
strict_flag=True,
832-
help="Make arguments prepended via Concatenate be truly positional-only",
832+
help="Enable additional checks that are technically correct but may be impractical "
833+
"in real code. For example, this prohibits partial overlap in TypedDict updates, "
834+
"and makes arguments prepended via Concatenate positional-only",
833835
group=strictness_group,
834836
)
835837

@@ -1155,6 +1157,8 @@ def add_invertible_flag(
11551157
parser.add_argument(
11561158
"--disable-memoryview-promotion", action="store_true", help=argparse.SUPPRESS
11571159
)
1160+
# This flag is deprecated, it has been moved to --extra-checks
1161+
parser.add_argument("--strict-concatenate", action="store_true", help=argparse.SUPPRESS)
11581162

11591163
# options specifying code to check
11601164
code_group = parser.add_argument_group(
@@ -1226,8 +1230,11 @@ def add_invertible_flag(
12261230
parser.error(f"Cannot find config file '{config_file}'")
12271231

12281232
options = Options()
1233+
strict_option_set = False
12291234

12301235
def set_strict_flags() -> None:
1236+
nonlocal strict_option_set
1237+
strict_option_set = True
12311238
for dest, value in strict_flag_assignments:
12321239
setattr(options, dest, value)
12331240

@@ -1379,6 +1386,8 @@ def set_strict_flags() -> None:
13791386
"Warning: --enable-recursive-aliases is deprecated;"
13801387
" recursive types are enabled by default"
13811388
)
1389+
if options.strict_concatenate and not strict_option_set:
1390+
print("Warning: --strict-concatenate is deprecated; use --extra-checks instead")
13821391

13831392
# Set target.
13841393
if special_opts.modules + special_opts.packages:

mypy/messages.py

+18
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,24 @@ def need_annotation_for_var(
17571757
def explicit_any(self, ctx: Context) -> None:
17581758
self.fail('Explicit "Any" is not allowed', ctx)
17591759

1760+
def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None:
1761+
self.fail(
1762+
"Unsupported type {} for ** expansion in TypedDict".format(
1763+
format_type(typ, self.options)
1764+
),
1765+
ctx,
1766+
code=codes.TYPEDDICT_ITEM,
1767+
)
1768+
1769+
def non_required_keys_absent_with_star(self, keys: list[str], ctx: Context) -> None:
1770+
self.fail(
1771+
"Non-required {} not explicitly found in any ** item".format(
1772+
format_key_list(keys, short=True)
1773+
),
1774+
ctx,
1775+
code=codes.TYPEDDICT_ITEM,
1776+
)
1777+
17601778
def unexpected_typeddict_keys(
17611779
self,
17621780
typ: TypedDictType,

mypy/options.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class BuildType:
4040
"disallow_untyped_defs",
4141
"enable_error_code",
4242
"enabled_error_codes",
43+
"extra_checks",
4344
"follow_imports_for_stubs",
4445
"follow_imports",
4546
"ignore_errors",
@@ -200,9 +201,12 @@ def __init__(self) -> None:
200201
# This makes 1 == '1', 1 in ['1'], and 1 is '1' errors.
201202
self.strict_equality = False
202203

203-
# Make arguments prepended via Concatenate be truly positional-only.
204+
# Deprecated, use extra_checks instead.
204205
self.strict_concatenate = False
205206

207+
# Enable additional checks that are technically correct but impractical.
208+
self.extra_checks = False
209+
206210
# Report an error for any branches inferred to be unreachable as a result of
207211
# type analysis.
208212
self.warn_unreachable = False

mypy/plugins/default.py

+27
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
TypedDictType,
3232
TypeOfAny,
3333
TypeVarType,
34+
UnionType,
3435
get_proper_type,
36+
get_proper_types,
3537
)
3638

3739

@@ -404,6 +406,31 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
404406
assert isinstance(arg_type, TypedDictType)
405407
arg_type = arg_type.as_anonymous()
406408
arg_type = arg_type.copy_modified(required_keys=set())
409+
if ctx.args and ctx.args[0]:
410+
with ctx.api.msg.filter_errors():
411+
inferred = get_proper_type(
412+
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
413+
)
414+
possible_tds = []
415+
if isinstance(inferred, TypedDictType):
416+
possible_tds = [inferred]
417+
elif isinstance(inferred, UnionType):
418+
possible_tds = [
419+
t
420+
for t in get_proper_types(inferred.relevant_items())
421+
if isinstance(t, TypedDictType)
422+
]
423+
items = []
424+
for td in possible_tds:
425+
item = arg_type.copy_modified(
426+
required_keys=(arg_type.required_keys | td.required_keys)
427+
& arg_type.items.keys()
428+
)
429+
if not ctx.api.options.extra_checks:
430+
item = item.copy_modified(item_names=list(td.items))
431+
items.append(item)
432+
if items:
433+
arg_type = make_simplified_union(items)
407434
return signature.copy_modified(arg_types=[arg_type])
408435
return signature
409436

mypy/semanal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5084,14 +5084,14 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None:
50845084
50855085
For other variants of dict(...), return None.
50865086
"""
5087-
if not all(kind == ARG_NAMED for kind in call.arg_kinds):
5087+
if not all(kind in (ARG_NAMED, ARG_STAR2) for kind in call.arg_kinds):
50885088
# Must still accept those args.
50895089
for a in call.args:
50905090
a.accept(self)
50915091
return None
50925092
expr = DictExpr(
50935093
[
5094-
(StrExpr(cast(str, key)), value) # since they are all ARG_NAMED
5094+
(StrExpr(key) if key is not None else None, value)
50955095
for key, value in zip(call.arg_names, call.args)
50965096
]
50975097
)

mypy/subtypes.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,9 @@ def visit_callable_type(self, left: CallableType) -> bool:
694694
right,
695695
is_compat=self._is_subtype,
696696
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
697-
strict_concatenate=self.options.strict_concatenate if self.options else True,
697+
strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate)
698+
if self.options
699+
else True,
698700
)
699701
elif isinstance(right, Overloaded):
700702
return all(self._is_subtype(left, item) for item in right.items)
@@ -858,7 +860,11 @@ def visit_overloaded(self, left: Overloaded) -> bool:
858860
else:
859861
# If this one overlaps with the supertype in any way, but it wasn't
860862
# an exact match, then it's a potential error.
861-
strict_concat = self.options.strict_concatenate if self.options else True
863+
strict_concat = (
864+
(self.options.extra_checks or self.options.strict_concatenate)
865+
if self.options
866+
else True
867+
)
862868
if left_index not in matched_overloads and (
863869
is_callable_compatible(
864870
left_item,

mypy/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -2437,6 +2437,7 @@ def copy_modified(
24372437
*,
24382438
fallback: Instance | None = None,
24392439
item_types: list[Type] | None = None,
2440+
item_names: list[str] | None = None,
24402441
required_keys: set[str] | None = None,
24412442
) -> TypedDictType:
24422443
if fallback is None:
@@ -2447,6 +2448,9 @@ def copy_modified(
24472448
items = dict(zip(self.items, item_types))
24482449
if required_keys is None:
24492450
required_keys = self.required_keys
2451+
if item_names is not None:
2452+
items = {k: v for (k, v) in items.items() if k in item_names}
2453+
required_keys &= set(item_names)
24502454
return TypedDictType(items, required_keys, fallback, self.line, self.column)
24512455

24522456
def create_anonymous_fallback(self) -> Instance:

test-data/unit/check-parameter-specification.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ reveal_type(f(n)) # N: Revealed type is "def (builtins.int, builtins.bytes) ->
570570
[builtins fixtures/paramspec.pyi]
571571

572572
[case testParamSpecConcatenateNamedArgs]
573-
# flags: --python-version 3.8 --strict-concatenate
573+
# flags: --python-version 3.8 --extra-checks
574574
# this is one noticeable deviation from PEP but I believe it is for the better
575575
from typing_extensions import ParamSpec, Concatenate
576576
from typing import Callable, TypeVar

0 commit comments

Comments
 (0)