Skip to content

Commit 35bc1a2

Browse files
ilevkivskyijhance
authored andcommitted
Allow using TypedDict for more precise typing of **kwds (#13471)
Fixes #4441 This uses a different approach than the initial attempt, but I re-used some of the test cases from the older PR. The initial idea was to eagerly expand the signature of the function during semantic analysis, but it didn't work well with fine-grained mode and also mypy in general relies on function definition and its type being consistent (and rewriting `FuncDef` sounds too sketchy). So instead I add a boolean flag to `CallableType` to indicate whether type of `**kwargs` is each item type or the "packed" type. I also add few helpers and safety net in form of a `NewType()`, but in general I am surprised how few places needed normalizing the signatures (because most relevant code paths go through `check_callable_call()` and/or `is_callable_compatible()`). Currently `Unpack[...]` is hidden behind `--enable-incomplete-features`, so this will be too, but IMO this part is 99% complete (you can see even some more exotic use cases like generic TypedDicts and callback protocols in test cases).
1 parent dd2e020 commit 35bc1a2

16 files changed

+505
-21
lines changed

mypy/checker.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
728728
# This is to match the direction the implementation's return
729729
# needs to be compatible in.
730730
if impl_type.variables:
731-
impl = unify_generic_callable(
732-
impl_type,
733-
sig1,
731+
impl: CallableType | None = unify_generic_callable(
732+
# Normalize both before unifying
733+
impl_type.with_unpacked_kwargs(),
734+
sig1.with_unpacked_kwargs(),
734735
ignore_return=False,
735736
return_constraint_direction=SUPERTYPE_OF,
736737
)
@@ -1165,7 +1166,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
11651166
# builtins.tuple[T] is typing.Tuple[T, ...]
11661167
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
11671168
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
1168-
if not isinstance(arg_type, ParamSpecType):
1169+
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
11691170
arg_type = self.named_generic_type(
11701171
"builtins.dict", [self.str_type(), arg_type]
11711172
)
@@ -1887,6 +1888,13 @@ def check_override(
18871888

18881889
if fail:
18891890
emitted_msg = False
1891+
1892+
# Normalize signatures, so we get better diagnostics.
1893+
if isinstance(override, (CallableType, Overloaded)):
1894+
override = override.with_unpacked_kwargs()
1895+
if isinstance(original, (CallableType, Overloaded)):
1896+
original = original.with_unpacked_kwargs()
1897+
18901898
if (
18911899
isinstance(override, CallableType)
18921900
and isinstance(original, CallableType)

mypy/checkexpr.py

+4
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,8 @@ def check_callable_call(
13221322
13231323
See the docstring of check_call for more information.
13241324
"""
1325+
# Always unpack **kwargs before checking a call.
1326+
callee = callee.with_unpacked_kwargs()
13251327
if callable_name is None and callee.name:
13261328
callable_name = callee.name
13271329
ret_type = get_proper_type(callee.ret_type)
@@ -2057,6 +2059,8 @@ def check_overload_call(
20572059
context: Context,
20582060
) -> tuple[Type, Type]:
20592061
"""Checks a call to an overloaded function."""
2062+
# Normalize unpacked kwargs before checking the call.
2063+
callee = callee.with_unpacked_kwargs()
20602064
arg_types = self.infer_arg_types_in_empty_context(args)
20612065
# Step 1: Filter call targets to remove ones where the argument counts don't match
20622066
plausible_targets = self.plausible_overload_call_targets(

mypy/constraints.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members(
735735
return res
736736

737737
def visit_callable_type(self, template: CallableType) -> list[Constraint]:
738+
# Normalize callables before matching against each other.
739+
# Note that non-normalized callables can be created in annotations
740+
# using e.g. callback protocols.
741+
template = template.with_unpacked_kwargs()
738742
if isinstance(self.actual, CallableType):
739743
res: list[Constraint] = []
740-
cactual = self.actual
744+
cactual = self.actual.with_unpacked_kwargs()
741745
param_spec = template.param_spec()
742746
if param_spec is None:
743747
# FIX verify argument counts

mypy/join.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from typing import Tuple
6+
57
import mypy.typeops
68
from mypy.maptype import map_instance_to_supertype
79
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
@@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
141143

142144
def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
143145
"""Return a simple least upper bound given the declared type."""
144-
# TODO: check infinite recursion for aliases here.
146+
# TODO: check infinite recursion for aliases here?
145147
declaration = get_proper_type(declaration)
146148
s = get_proper_type(s)
147149
t = get_proper_type(t)
@@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
172174
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
173175
s, t = t, s
174176

177+
# Meets/joins require callable type normalization.
178+
s, t = normalize_callables(s, t)
179+
175180
value = t.accept(TypeJoinVisitor(s))
176181
if declaration is None or is_subtype(value, declaration):
177182
return value
@@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
229234
elif isinstance(t, PlaceholderType):
230235
return AnyType(TypeOfAny.from_error)
231236

237+
# Meets/joins require callable type normalization.
238+
s, t = normalize_callables(s, t)
239+
232240
# Use a visitor to handle non-trivial cases.
233241
return t.accept(TypeJoinVisitor(s, instance_joiner))
234242

@@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
528536
return False
529537

530538

539+
def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
540+
if isinstance(s, (CallableType, Overloaded)):
541+
s = s.with_unpacked_kwargs()
542+
if isinstance(t, (CallableType, Overloaded)):
543+
t = t.with_unpacked_kwargs()
544+
return s, t
545+
546+
531547
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
532548
"""Return True if t and s have identical numbers of
533549
arguments, default arguments and varargs.

mypy/meet.py

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
7878
return t
7979
if isinstance(s, UnionType) and not isinstance(t, UnionType):
8080
s, t = t, s
81+
82+
# Meets/joins require callable type normalization.
83+
s, t = join.normalize_callables(s, t)
84+
8185
return t.accept(TypeMeetVisitor(s))
8286

8387

mypy/messages.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None
23912391
name = tp.arg_names[i]
23922392
if name:
23932393
s += name + ": "
2394-
s += format_type_bare(tp.arg_types[i])
2394+
type_str = format_type_bare(tp.arg_types[i])
2395+
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
2396+
type_str = f"Unpack[{type_str}]"
2397+
s += type_str
23952398
if tp.arg_kinds[i].is_optional():
23962399
s += " = ..."
23972400

mypy/semanal.py

+26
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@
263263
TypeVarLikeType,
264264
TypeVarType,
265265
UnboundType,
266+
UnpackType,
266267
get_proper_type,
267268
get_proper_types,
268269
invalid_recursive_alias,
@@ -832,6 +833,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
832833
self.defer(defn)
833834
return
834835
assert isinstance(result, ProperType)
836+
if isinstance(result, CallableType):
837+
result = self.remove_unpack_kwargs(defn, result)
835838
defn.type = result
836839
self.add_type_alias_deps(analyzer.aliases_used)
837840
self.check_function_signature(defn)
@@ -874,6 +877,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
874877
defn.type = defn.type.copy_modified(ret_type=ret_type)
875878
self.wrapped_coro_return_types[defn] = defn.type
876879

880+
def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
881+
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
882+
return typ
883+
last_type = get_proper_type(typ.arg_types[-1])
884+
if not isinstance(last_type, UnpackType):
885+
return typ
886+
last_type = get_proper_type(last_type.type)
887+
if not isinstance(last_type, TypedDictType):
888+
self.fail("Unpack item in ** argument must be a TypedDict", defn)
889+
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
890+
return typ.copy_modified(arg_types=new_arg_types)
891+
overlap = set(typ.arg_names) & set(last_type.items)
892+
# It is OK for TypedDict to have a key named 'kwargs'.
893+
overlap.discard(typ.arg_names[-1])
894+
if overlap:
895+
overlapped = ", ".join([f'"{name}"' for name in overlap])
896+
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
897+
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
898+
return typ.copy_modified(arg_types=new_arg_types)
899+
# OK, everything looks right now, mark the callable type as using unpack.
900+
new_arg_types = typ.arg_types[:-1] + [last_type]
901+
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)
902+
877903
def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
878904
"""Check basic signature validity and tweak annotation of self/cls argument."""
879905
# Only non-static methods are special.

mypy/subtypes.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
Instance,
3939
LiteralType,
4040
NoneType,
41+
NormalizedCallableType,
4142
Overloaded,
4243
Parameters,
4344
ParamSpecType,
@@ -591,8 +592,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
591592
return False
592593

593594
def visit_parameters(self, left: Parameters) -> bool:
594-
right = self.right
595-
if isinstance(right, Parameters) or isinstance(right, CallableType):
595+
if isinstance(self.right, Parameters) or isinstance(self.right, CallableType):
596+
right = self.right
597+
if isinstance(right, CallableType):
598+
right = right.with_unpacked_kwargs()
596599
return are_parameters_compatible(
597600
left,
598601
right,
@@ -636,7 +639,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
636639
elif isinstance(right, Parameters):
637640
# this doesn't check return types.... but is needed for is_equivalent
638641
return are_parameters_compatible(
639-
left,
642+
left.with_unpacked_kwargs(),
640643
right,
641644
is_compat=self._is_subtype,
642645
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
@@ -1213,6 +1216,10 @@ def g(x: int) -> int: ...
12131216
If the 'some_check' function is also symmetric, the two calls would be equivalent
12141217
whether or not we check the args covariantly.
12151218
"""
1219+
# Normalize both types before comparing them.
1220+
left = left.with_unpacked_kwargs()
1221+
right = right.with_unpacked_kwargs()
1222+
12161223
if is_compat_return is None:
12171224
is_compat_return = is_compat
12181225

@@ -1277,8 +1284,8 @@ def g(x: int) -> int: ...
12771284

12781285

12791286
def are_parameters_compatible(
1280-
left: Parameters | CallableType,
1281-
right: Parameters | CallableType,
1287+
left: Parameters | NormalizedCallableType,
1288+
right: Parameters | NormalizedCallableType,
12821289
*,
12831290
is_compat: Callable[[Type, Type], bool],
12841291
ignore_pos_arg_names: bool = False,
@@ -1499,11 +1506,11 @@ def new_is_compat(left: Type, right: Type) -> bool:
14991506

15001507

15011508
def unify_generic_callable(
1502-
type: CallableType,
1503-
target: CallableType,
1509+
type: NormalizedCallableType,
1510+
target: NormalizedCallableType,
15041511
ignore_return: bool,
15051512
return_constraint_direction: int | None = None,
1506-
) -> CallableType | None:
1513+
) -> NormalizedCallableType | None:
15071514
"""Try to unify a generic callable type with another callable type.
15081515
15091516
Return unified CallableType if successful; otherwise, return None.
@@ -1540,7 +1547,7 @@ def report(*args: Any) -> None:
15401547
)
15411548
if had_errors:
15421549
return None
1543-
return applied
1550+
return cast(NormalizedCallableType, applied)
15441551

15451552

15461553
def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:

mypy/typeanal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
538538
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
539539
# We don't want people to try to use this yet.
540540
if not self.options.enable_incomplete_features:
541-
self.fail('"Unpack" is not supported by mypy yet', t)
541+
self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t)
542542
return AnyType(TypeOfAny.from_error)
543543
return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
544544
return None

0 commit comments

Comments
 (0)