Skip to content

Commit 3973981

Browse files
authored
Allow using TypedDict for more precise typing of **kwds (#13471)
Fixes #4441 This uses a different approach than the initial attempt, but I re-used some of the test cases from the older PR. The initial idea was to eagerly expand the signature of the function during semantic analysis, but it didn't work well with fine-grained mode and also mypy in general relies on function definition and its type being consistent (and rewriting `FuncDef` sounds too sketchy). So instead I add a boolean flag to `CallableType` to indicate whether type of `**kwargs` is each item type or the "packed" type. I also add few helpers and safety net in form of a `NewType()`, but in general I am surprised how few places needed normalizing the signatures (because most relevant code paths go through `check_callable_call()` and/or `is_callable_compatible()`). Currently `Unpack[...]` is hidden behind `--enable-incomplete-features`, so this will be too, but IMO this part is 99% complete (you can see even some more exotic use cases like generic TypedDicts and callback protocols in test cases).
1 parent d89b28d commit 3973981

16 files changed

+505
-21
lines changed

mypy/checker.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
730730
# This is to match the direction the implementation's return
731731
# needs to be compatible in.
732732
if impl_type.variables:
733-
impl = unify_generic_callable(
734-
impl_type,
735-
sig1,
733+
impl: CallableType | None = unify_generic_callable(
734+
# Normalize both before unifying
735+
impl_type.with_unpacked_kwargs(),
736+
sig1.with_unpacked_kwargs(),
736737
ignore_return=False,
737738
return_constraint_direction=SUPERTYPE_OF,
738739
)
@@ -1167,7 +1168,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
11671168
# builtins.tuple[T] is typing.Tuple[T, ...]
11681169
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
11691170
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
1170-
if not isinstance(arg_type, ParamSpecType):
1171+
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
11711172
arg_type = self.named_generic_type(
11721173
"builtins.dict", [self.str_type(), arg_type]
11731174
)
@@ -1912,6 +1913,13 @@ def check_override(
19121913

19131914
if fail:
19141915
emitted_msg = False
1916+
1917+
# Normalize signatures, so we get better diagnostics.
1918+
if isinstance(override, (CallableType, Overloaded)):
1919+
override = override.with_unpacked_kwargs()
1920+
if isinstance(original, (CallableType, Overloaded)):
1921+
original = original.with_unpacked_kwargs()
1922+
19151923
if (
19161924
isinstance(override, CallableType)
19171925
and isinstance(original, CallableType)

mypy/checkexpr.py

+4
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,8 @@ def check_callable_call(
13221322
13231323
See the docstring of check_call for more information.
13241324
"""
1325+
# Always unpack **kwargs before checking a call.
1326+
callee = callee.with_unpacked_kwargs()
13251327
if callable_name is None and callee.name:
13261328
callable_name = callee.name
13271329
ret_type = get_proper_type(callee.ret_type)
@@ -2057,6 +2059,8 @@ def check_overload_call(
20572059
context: Context,
20582060
) -> tuple[Type, Type]:
20592061
"""Checks a call to an overloaded function."""
2062+
# Normalize unpacked kwargs before checking the call.
2063+
callee = callee.with_unpacked_kwargs()
20602064
arg_types = self.infer_arg_types_in_empty_context(args)
20612065
# Step 1: Filter call targets to remove ones where the argument counts don't match
20622066
plausible_targets = self.plausible_overload_call_targets(

mypy/constraints.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -763,9 +763,13 @@ def infer_constraints_from_protocol_members(
763763
return res
764764

765765
def visit_callable_type(self, template: CallableType) -> list[Constraint]:
766+
# Normalize callables before matching against each other.
767+
# Note that non-normalized callables can be created in annotations
768+
# using e.g. callback protocols.
769+
template = template.with_unpacked_kwargs()
766770
if isinstance(self.actual, CallableType):
767771
res: list[Constraint] = []
768-
cactual = self.actual
772+
cactual = self.actual.with_unpacked_kwargs()
769773
param_spec = template.param_spec()
770774
if param_spec is None:
771775
# FIX verify argument counts

mypy/join.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from typing import Tuple
6+
57
import mypy.typeops
68
from mypy.maptype import map_instance_to_supertype
79
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
@@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
141143

142144
def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
143145
"""Return a simple least upper bound given the declared type."""
144-
# TODO: check infinite recursion for aliases here.
146+
# TODO: check infinite recursion for aliases here?
145147
declaration = get_proper_type(declaration)
146148
s = get_proper_type(s)
147149
t = get_proper_type(t)
@@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
172174
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
173175
s, t = t, s
174176

177+
# Meets/joins require callable type normalization.
178+
s, t = normalize_callables(s, t)
179+
175180
value = t.accept(TypeJoinVisitor(s))
176181
if declaration is None or is_subtype(value, declaration):
177182
return value
@@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
229234
elif isinstance(t, PlaceholderType):
230235
return AnyType(TypeOfAny.from_error)
231236

237+
# Meets/joins require callable type normalization.
238+
s, t = normalize_callables(s, t)
239+
232240
# Use a visitor to handle non-trivial cases.
233241
return t.accept(TypeJoinVisitor(s, instance_joiner))
234242

@@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
528536
return False
529537

530538

539+
def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
540+
if isinstance(s, (CallableType, Overloaded)):
541+
s = s.with_unpacked_kwargs()
542+
if isinstance(t, (CallableType, Overloaded)):
543+
t = t.with_unpacked_kwargs()
544+
return s, t
545+
546+
531547
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
532548
"""Return True if t and s have identical numbers of
533549
arguments, default arguments and varargs.

mypy/meet.py

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
7878
return t
7979
if isinstance(s, UnionType) and not isinstance(t, UnionType):
8080
s, t = t, s
81+
82+
# Meets/joins require callable type normalization.
83+
s, t = join.normalize_callables(s, t)
84+
8185
return t.accept(TypeMeetVisitor(s))
8286

8387

mypy/messages.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2392,7 +2392,10 @@ def [T <: int] f(self, x: int, y: T) -> None
23922392
name = tp.arg_names[i]
23932393
if name:
23942394
s += name + ": "
2395-
s += format_type_bare(tp.arg_types[i])
2395+
type_str = format_type_bare(tp.arg_types[i])
2396+
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
2397+
type_str = f"Unpack[{type_str}]"
2398+
s += type_str
23962399
if tp.arg_kinds[i].is_optional():
23972400
s += " = ..."
23982401
if (

mypy/semanal.py

+26
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@
263263
TypeVarLikeType,
264264
TypeVarType,
265265
UnboundType,
266+
UnpackType,
266267
get_proper_type,
267268
get_proper_types,
268269
invalid_recursive_alias,
@@ -830,6 +831,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
830831
self.defer(defn)
831832
return
832833
assert isinstance(result, ProperType)
834+
if isinstance(result, CallableType):
835+
result = self.remove_unpack_kwargs(defn, result)
833836
defn.type = result
834837
self.add_type_alias_deps(analyzer.aliases_used)
835838
self.check_function_signature(defn)
@@ -872,6 +875,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
872875
defn.type = defn.type.copy_modified(ret_type=ret_type)
873876
self.wrapped_coro_return_types[defn] = defn.type
874877

878+
def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
879+
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
880+
return typ
881+
last_type = get_proper_type(typ.arg_types[-1])
882+
if not isinstance(last_type, UnpackType):
883+
return typ
884+
last_type = get_proper_type(last_type.type)
885+
if not isinstance(last_type, TypedDictType):
886+
self.fail("Unpack item in ** argument must be a TypedDict", defn)
887+
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
888+
return typ.copy_modified(arg_types=new_arg_types)
889+
overlap = set(typ.arg_names) & set(last_type.items)
890+
# It is OK for TypedDict to have a key named 'kwargs'.
891+
overlap.discard(typ.arg_names[-1])
892+
if overlap:
893+
overlapped = ", ".join([f'"{name}"' for name in overlap])
894+
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
895+
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
896+
return typ.copy_modified(arg_types=new_arg_types)
897+
# OK, everything looks right now, mark the callable type as using unpack.
898+
new_arg_types = typ.arg_types[:-1] + [last_type]
899+
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)
900+
875901
def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
876902
"""Check basic signature validity and tweak annotation of self/cls argument."""
877903
# Only non-static methods are special.

mypy/subtypes.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
Instance,
3939
LiteralType,
4040
NoneType,
41+
NormalizedCallableType,
4142
Overloaded,
4243
Parameters,
4344
ParamSpecType,
@@ -626,8 +627,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
626627
return False
627628

628629
def visit_parameters(self, left: Parameters) -> bool:
629-
right = self.right
630-
if isinstance(right, Parameters) or isinstance(right, CallableType):
630+
if isinstance(self.right, Parameters) or isinstance(self.right, CallableType):
631+
right = self.right
632+
if isinstance(right, CallableType):
633+
right = right.with_unpacked_kwargs()
631634
return are_parameters_compatible(
632635
left,
633636
right,
@@ -671,7 +674,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
671674
elif isinstance(right, Parameters):
672675
# this doesn't check return types.... but is needed for is_equivalent
673676
return are_parameters_compatible(
674-
left,
677+
left.with_unpacked_kwargs(),
675678
right,
676679
is_compat=self._is_subtype,
677680
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
@@ -1249,6 +1252,10 @@ def g(x: int) -> int: ...
12491252
If the 'some_check' function is also symmetric, the two calls would be equivalent
12501253
whether or not we check the args covariantly.
12511254
"""
1255+
# Normalize both types before comparing them.
1256+
left = left.with_unpacked_kwargs()
1257+
right = right.with_unpacked_kwargs()
1258+
12521259
if is_compat_return is None:
12531260
is_compat_return = is_compat
12541261

@@ -1313,8 +1320,8 @@ def g(x: int) -> int: ...
13131320

13141321

13151322
def are_parameters_compatible(
1316-
left: Parameters | CallableType,
1317-
right: Parameters | CallableType,
1323+
left: Parameters | NormalizedCallableType,
1324+
right: Parameters | NormalizedCallableType,
13181325
*,
13191326
is_compat: Callable[[Type, Type], bool],
13201327
ignore_pos_arg_names: bool = False,
@@ -1535,11 +1542,11 @@ def new_is_compat(left: Type, right: Type) -> bool:
15351542

15361543

15371544
def unify_generic_callable(
1538-
type: CallableType,
1539-
target: CallableType,
1545+
type: NormalizedCallableType,
1546+
target: NormalizedCallableType,
15401547
ignore_return: bool,
15411548
return_constraint_direction: int | None = None,
1542-
) -> CallableType | None:
1549+
) -> NormalizedCallableType | None:
15431550
"""Try to unify a generic callable type with another callable type.
15441551
15451552
Return unified CallableType if successful; otherwise, return None.
@@ -1576,7 +1583,7 @@ def report(*args: Any) -> None:
15761583
)
15771584
if had_errors:
15781585
return None
1579-
return applied
1586+
return cast(NormalizedCallableType, applied)
15801587

15811588

15821589
def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:

mypy/typeanal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
538538
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
539539
# We don't want people to try to use this yet.
540540
if not self.options.enable_incomplete_features:
541-
self.fail('"Unpack" is not supported by mypy yet', t)
541+
self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t)
542542
return AnyType(TypeOfAny.from_error)
543543
return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
544544
return None

0 commit comments

Comments
 (0)