Skip to content

Commit 9e520c3

Browse files
authored
Allow TypedDict unpacking in Callable types (#16083)
Fixes #16082 Currently we only allow `Unpack` of a TypedDict when it appears in a function definition. This PR also allows this in `Callable` types, similarly to how we do this for variadic types. Note this still doesn't allow having both variadic unpack and a TypedDict unpack in the same `Callable`. Supporting this is tricky, so let's not so this until people will actually ask for this. FWIW we can always suggest callback protocols for such tricky cases.
1 parent 9a35360 commit 9e520c3

File tree

6 files changed

+39
-6
lines changed

6 files changed

+39
-6
lines changed

mypy/exprtotype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def expr_to_unanalyzed_type(
196196
elif isinstance(expr, EllipsisExpr):
197197
return EllipsisType(expr.line)
198198
elif allow_unpack and isinstance(expr, StarExpr):
199-
return UnpackType(expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax))
199+
return UnpackType(
200+
expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax), from_star_syntax=True
201+
)
200202
else:
201203
raise TypeTranslationError()

mypy/fastparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2041,7 +2041,7 @@ def visit_Attribute(self, n: Attribute) -> Type:
20412041

20422042
# Used for Callable[[X *Ys, Z], R]
20432043
def visit_Starred(self, n: ast3.Starred) -> Type:
2044-
return UnpackType(self.visit(n.value))
2044+
return UnpackType(self.visit(n.value), from_star_syntax=True)
20452045

20462046
# List(expr* elts, expr_context ctx)
20472047
def visit_List(self, n: ast3.List) -> Type:

mypy/semanal_typeargs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ def visit_unpack_type(self, typ: UnpackType) -> None:
214214
# Avoid extra errors if there were some errors already. Also interpret plain Any
215215
# as tuple[Any, ...] (this is better for the code in type checker).
216216
self.fail(
217-
message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)), typ
217+
message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)),
218+
typ.type,
219+
code=codes.VALID_TYPE,
218220
)
219221
typ.type = self.named_type("builtins.tuple", [AnyType(TypeOfAny.from_error)])
220222

mypy/typeanal.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,14 +961,15 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
961961
if not self.allow_unpack:
962962
self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE)
963963
return AnyType(TypeOfAny.from_error)
964-
return UnpackType(self.anal_type(t.type))
964+
return UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax)
965965

966966
def visit_parameters(self, t: Parameters) -> Type:
967967
raise NotImplementedError("ParamSpec literals cannot have unbound TypeVars")
968968

969969
def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
970970
# Every Callable can bind its own type variables, if they're not in the outer scope
971971
with self.tvar_scope_frame():
972+
unpacked_kwargs = False
972973
if self.defining_alias:
973974
variables = t.variables
974975
else:
@@ -996,6 +997,15 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
996997
)
997998
validated_args.append(AnyType(TypeOfAny.from_error))
998999
else:
1000+
if nested and isinstance(at, UnpackType) and i == star_index:
1001+
# TODO: it would be better to avoid this get_proper_type() call.
1002+
p_at = get_proper_type(at.type)
1003+
if isinstance(p_at, TypedDictType) and not at.from_star_syntax:
1004+
# Automatically detect Unpack[Foo] in Callable as backwards
1005+
# compatible syntax for **Foo, if Foo is a TypedDict.
1006+
at = p_at
1007+
arg_kinds[i] = ARG_STAR2
1008+
unpacked_kwargs = True
9991009
validated_args.append(at)
10001010
arg_types = validated_args
10011011
# If there were multiple (invalid) unpacks, the arg types list will become shorter,
@@ -1013,6 +1023,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
10131023
fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")),
10141024
variables=self.anal_var_defs(variables),
10151025
type_guard=special,
1026+
unpack_kwargs=unpacked_kwargs,
10161027
)
10171028
return ret
10181029

mypy/types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,11 +1053,14 @@ class UnpackType(ProperType):
10531053
wild west, technically anything can be present in the wrapped type.
10541054
"""
10551055

1056-
__slots__ = ["type"]
1056+
__slots__ = ["type", "from_star_syntax"]
10571057

1058-
def __init__(self, typ: Type, line: int = -1, column: int = -1) -> None:
1058+
def __init__(
1059+
self, typ: Type, line: int = -1, column: int = -1, from_star_syntax: bool = False
1060+
) -> None:
10591061
super().__init__(line, column)
10601062
self.type = typ
1063+
self.from_star_syntax = from_star_syntax
10611064

10621065
def accept(self, visitor: TypeVisitor[T]) -> T:
10631066
return visitor.visit_unpack_type(self)

test-data/unit/check-varargs.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,3 +1079,18 @@ class C:
10791079
class D:
10801080
def __init__(self, **kwds: Unpack[int, str]) -> None: ... # E: Unpack[...] requires exactly one type argument
10811081
[builtins fixtures/dict.pyi]
1082+
1083+
[case testUnpackInCallableType]
1084+
from typing import Callable
1085+
from typing_extensions import Unpack, TypedDict
1086+
1087+
class TD(TypedDict):
1088+
key: str
1089+
value: str
1090+
1091+
foo: Callable[[Unpack[TD]], None]
1092+
foo(key="yes", value=42) # E: Argument "value" has incompatible type "int"; expected "str"
1093+
foo(key="yes", value="ok")
1094+
1095+
bad: Callable[[*TD], None] # E: "TD" cannot be unpacked (must be tuple or TypeVarTuple)
1096+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)