Skip to content

Commit fa84534

Browse files
Basic support for decorated overloads (#15898)
Fixes #15737 Fixes #12844 Fixes #12716 My goal was to fix the `ParamSpec` issues, but it turns out decorated overloads were not supported at all. Namely: * Decorators on overload items were ignored, caller would see original undecorated item types * Overload item overlap checks were performed for original types, while arguably we should use decorated types * Overload items completeness w.r.t. to implementation was checked with decorated implementation, and undecorated items Here I add basic support using same logic as for regular decorated functions: initially set type to `None` and defer callers until definition is type-checked. Note this results in few more `Cannot determine type` in case of other errors, but I think it is fine. Note I also add special-casing for "inline" applications of generic functions to overload arguments. This use case was mentioned few times alongside overloads. The general fix would be tricky, and my special-casing should cover typical use cases. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b3d0937 commit fa84534

9 files changed

+206
-45
lines changed

mypy/checker.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -636,13 +636,30 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
636636
self.visit_decorator(defn.items[0])
637637
for fdef in defn.items:
638638
assert isinstance(fdef, Decorator)
639-
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
639+
if defn.is_property:
640+
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
641+
else:
642+
# Perform full check for real overloads to infer type of all decorated
643+
# overload variants.
644+
self.visit_decorator_inner(fdef, allow_empty=True)
640645
if fdef.func.abstract_status in (IS_ABSTRACT, IMPLICITLY_ABSTRACT):
641646
num_abstract += 1
642647
if num_abstract not in (0, len(defn.items)):
643648
self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn)
644649
if defn.impl:
645650
defn.impl.accept(self)
651+
if not defn.is_property:
652+
self.check_overlapping_overloads(defn)
653+
if defn.type is None:
654+
item_types = []
655+
for item in defn.items:
656+
assert isinstance(item, Decorator)
657+
item_type = self.extract_callable_type(item.var.type, item)
658+
if item_type is not None:
659+
item_types.append(item_type)
660+
if item_types:
661+
defn.type = Overloaded(item_types)
662+
# Check override validity after we analyzed current definition.
646663
if defn.info:
647664
found_method_base_classes = self.check_method_override(defn)
648665
if (
@@ -653,10 +670,35 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
653670
self.msg.no_overridable_method(defn.name, defn)
654671
self.check_explicit_override_decorator(defn, found_method_base_classes, defn.impl)
655672
self.check_inplace_operator_method(defn)
656-
if not defn.is_property:
657-
self.check_overlapping_overloads(defn)
658673
return None
659674

675+
def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
676+
"""Get type as seen by an overload item caller."""
677+
inner_type = get_proper_type(inner_type)
678+
outer_type: CallableType | None = None
679+
if inner_type is not None and not isinstance(inner_type, AnyType):
680+
if isinstance(inner_type, CallableType):
681+
outer_type = inner_type
682+
elif isinstance(inner_type, Instance):
683+
inner_call = get_proper_type(
684+
analyze_member_access(
685+
name="__call__",
686+
typ=inner_type,
687+
context=ctx,
688+
is_lvalue=False,
689+
is_super=False,
690+
is_operator=True,
691+
msg=self.msg,
692+
original_type=inner_type,
693+
chk=self,
694+
)
695+
)
696+
if isinstance(inner_call, CallableType):
697+
outer_type = inner_call
698+
if outer_type is None:
699+
self.msg.not_callable(inner_type, ctx)
700+
return outer_type
701+
660702
def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
661703
# At this point we should have set the impl already, and all remaining
662704
# items are decorators
@@ -680,40 +722,20 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
680722

681723
# This can happen if we've got an overload with a different
682724
# decorator or if the implementation is untyped -- we gave up on the types.
683-
inner_type = get_proper_type(inner_type)
684-
if inner_type is not None and not isinstance(inner_type, AnyType):
685-
if isinstance(inner_type, CallableType):
686-
impl_type = inner_type
687-
elif isinstance(inner_type, Instance):
688-
inner_call = get_proper_type(
689-
analyze_member_access(
690-
name="__call__",
691-
typ=inner_type,
692-
context=defn.impl,
693-
is_lvalue=False,
694-
is_super=False,
695-
is_operator=True,
696-
msg=self.msg,
697-
original_type=inner_type,
698-
chk=self,
699-
)
700-
)
701-
if isinstance(inner_call, CallableType):
702-
impl_type = inner_call
703-
if impl_type is None:
704-
self.msg.not_callable(inner_type, defn.impl)
725+
impl_type = self.extract_callable_type(inner_type, defn.impl)
705726

706727
is_descriptor_get = defn.info and defn.name == "__get__"
707728
for i, item in enumerate(defn.items):
708-
# TODO overloads involving decorators
709729
assert isinstance(item, Decorator)
710-
sig1 = self.function_type(item.func)
711-
assert isinstance(sig1, CallableType)
730+
sig1 = self.extract_callable_type(item.var.type, item)
731+
if sig1 is None:
732+
continue
712733

713734
for j, item2 in enumerate(defn.items[i + 1 :]):
714735
assert isinstance(item2, Decorator)
715-
sig2 = self.function_type(item2.func)
716-
assert isinstance(sig2, CallableType)
736+
sig2 = self.extract_callable_type(item2.var.type, item2)
737+
if sig2 is None:
738+
continue
717739

718740
if not are_argument_counts_overlapping(sig1, sig2):
719741
continue
@@ -4751,17 +4773,20 @@ def visit_decorator(self, e: Decorator) -> None:
47514773
e.var.type = AnyType(TypeOfAny.special_form)
47524774
e.var.is_ready = True
47534775
return
4776+
self.visit_decorator_inner(e)
47544777

4778+
def visit_decorator_inner(self, e: Decorator, allow_empty: bool = False) -> None:
47554779
if self.recurse_into_functions:
47564780
with self.tscope.function_scope(e.func):
4757-
self.check_func_item(e.func, name=e.func.name)
4781+
self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty)
47584782

47594783
# Process decorators from the inside out to determine decorated signature, which
47604784
# may be different from the declared signature.
47614785
sig: Type = self.function_type(e.func)
47624786
for d in reversed(e.decorators):
47634787
if refers_to_fullname(d, OVERLOAD_NAMES):
4764-
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
4788+
if not allow_empty:
4789+
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
47654790
continue
47664791
dec = self.expr_checker.accept(d)
47674792
temp = self.temp_node(sig, context=e)
@@ -4788,6 +4813,8 @@ def visit_decorator(self, e: Decorator) -> None:
47884813
self.msg.fail("Too many arguments for property", e)
47894814
self.check_incompatible_property_override(e)
47904815
# For overloaded functions we already checked override for overload as a whole.
4816+
if allow_empty:
4817+
return
47914818
if e.func.info and not e.func.is_dynamic() and not e.is_overload:
47924819
found_method_base_classes = self.check_method_override(e)
47934820
if (

mypy/checkexpr.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,13 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
353353
elif isinstance(node, FuncDef):
354354
# Reference to a global function.
355355
result = function_type(node, self.named_type("builtins.function"))
356-
elif isinstance(node, OverloadedFuncDef) and node.type is not None:
357-
# node.type is None when there are multiple definitions of a function
358-
# and it's decorated by something that is not typing.overload
359-
# TODO: use a dummy Overloaded instead of AnyType in this case
360-
# like we do in mypy.types.function_type()?
361-
result = node.type
356+
elif isinstance(node, OverloadedFuncDef):
357+
if node.type is None:
358+
if self.chk.in_checked_function() and node.items:
359+
self.chk.handle_cannot_determine_type(node.name, e)
360+
result = AnyType(TypeOfAny.from_error)
361+
else:
362+
result = node.type
362363
elif isinstance(node, TypeInfo):
363364
# Reference to a type object.
364365
if node.typeddict_type:
@@ -1337,6 +1338,55 @@ def transform_callee_type(
13371338

13381339
return callee
13391340

1341+
def is_generic_decorator_overload_call(
1342+
self, callee_type: CallableType, args: list[Expression]
1343+
) -> Overloaded | None:
1344+
"""Check if this looks like an application of a generic function to overload argument."""
1345+
assert callee_type.variables
1346+
if len(callee_type.arg_types) != 1 or len(args) != 1:
1347+
# TODO: can we handle more general cases?
1348+
return None
1349+
if not isinstance(get_proper_type(callee_type.arg_types[0]), CallableType):
1350+
return None
1351+
if not isinstance(get_proper_type(callee_type.ret_type), CallableType):
1352+
return None
1353+
with self.chk.local_type_map():
1354+
with self.msg.filter_errors():
1355+
arg_type = get_proper_type(self.accept(args[0], type_context=None))
1356+
if isinstance(arg_type, Overloaded):
1357+
return arg_type
1358+
return None
1359+
1360+
def handle_decorator_overload_call(
1361+
self, callee_type: CallableType, overloaded: Overloaded, ctx: Context
1362+
) -> tuple[Type, Type] | None:
1363+
"""Type-check application of a generic callable to an overload.
1364+
1365+
We check call on each individual overload item, and then combine results into a new
1366+
overload. This function should be only used if callee_type takes and returns a Callable.
1367+
"""
1368+
result = []
1369+
inferred_args = []
1370+
for item in overloaded.items:
1371+
arg = TempNode(typ=item)
1372+
with self.msg.filter_errors() as err:
1373+
item_result, inferred_arg = self.check_call(callee_type, [arg], [ARG_POS], ctx)
1374+
if err.has_new_errors():
1375+
# This overload doesn't match.
1376+
continue
1377+
p_item_result = get_proper_type(item_result)
1378+
if not isinstance(p_item_result, CallableType):
1379+
continue
1380+
p_inferred_arg = get_proper_type(inferred_arg)
1381+
if not isinstance(p_inferred_arg, CallableType):
1382+
continue
1383+
inferred_args.append(p_inferred_arg)
1384+
result.append(p_item_result)
1385+
if not result or not inferred_args:
1386+
# None of the overload matched (or overload was initially malformed).
1387+
return None
1388+
return Overloaded(result), Overloaded(inferred_args)
1389+
13401390
def check_call_expr_with_callee_type(
13411391
self,
13421392
callee_type: Type,
@@ -1451,6 +1501,17 @@ def check_call(
14511501
callee = get_proper_type(callee)
14521502

14531503
if isinstance(callee, CallableType):
1504+
if callee.variables:
1505+
overloaded = self.is_generic_decorator_overload_call(callee, args)
1506+
if overloaded is not None:
1507+
# Special casing for inline application of generic callables to overloads.
1508+
# Supporting general case would be tricky, but this should cover 95% of cases.
1509+
overloaded_result = self.handle_decorator_overload_call(
1510+
callee, overloaded, context
1511+
)
1512+
if overloaded_result is not None:
1513+
return overloaded_result
1514+
14541515
return self.check_callable_call(
14551516
callee,
14561517
args,

mypy/checkmember.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,17 @@ def analyze_instance_member_access(
317317
return analyze_var(name, first_item.var, typ, info, mx)
318318
if mx.is_lvalue:
319319
mx.msg.cant_assign_to_method(mx.context)
320-
signature = function_type(method, mx.named_type("builtins.function"))
320+
if not isinstance(method, OverloadedFuncDef):
321+
signature = function_type(method, mx.named_type("builtins.function"))
322+
else:
323+
if method.type is None:
324+
# Overloads may be not ready if they are decorated. Handle this in same
325+
# manner as we would handle a regular decorated function: defer if possible.
326+
if not mx.no_deferral and method.items:
327+
mx.not_ready_callback(method.name, mx.context)
328+
return AnyType(TypeOfAny.special_form)
329+
assert isinstance(method.type, Overloaded)
330+
signature = method.type
321331
signature = freshen_all_functions_type_vars(signature)
322332
if not method.is_static:
323333
if name != "__call__":

mypy/semanal.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,16 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
11531153
elif not non_overload_indexes:
11541154
self.handle_missing_overload_implementation(defn)
11551155

1156-
if types:
1156+
if types and not any(
1157+
# If some overload items are decorated with other decorators, then
1158+
# the overload type will be determined during type checking.
1159+
isinstance(it, Decorator) and len(it.decorators) > 1
1160+
for it in defn.items
1161+
):
1162+
# TODO: should we enforce decorated overloads consistency somehow?
1163+
# Some existing code uses both styles:
1164+
# * Put decorator only on implementation, use "effective" types in overloads
1165+
# * Put decorator everywhere, use "bare" types in overloads.
11571166
defn.type = Overloaded(types)
11581167
defn.type.line = defn.line
11591168

test-data/unit/check-generics.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3062,10 +3062,10 @@ def dec5(f: Callable[[int], T]) -> Callable[[int], List[T]]:
30623062
reveal_type(dec1(lambda x: x)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]"
30633063
reveal_type(dec2(lambda x: x)) # N: Revealed type is "def [S] (S`4) -> builtins.list[S`4]"
30643064
reveal_type(dec3(lambda x: x[0])) # N: Revealed type is "def [S] (S`6) -> S`6"
3065-
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`8) -> S`8"
3065+
reveal_type(dec4(lambda x: [x])) # N: Revealed type is "def [S] (S`9) -> S`9"
30663066
reveal_type(dec1(lambda x: 1)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
30673067
reveal_type(dec5(lambda x: x)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
3068-
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]"
3068+
reveal_type(dec3(lambda x: x)) # N: Revealed type is "def [S] (S`16) -> builtins.list[S`16]"
30693069
dec4(lambda x: x) # E: Incompatible return value type (got "S", expected "List[object]")
30703070
[builtins fixtures/list.pyi]
30713071

test-data/unit/check-newsemanal.test

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,8 +3207,7 @@ class User:
32073207
self.first_name = value
32083208

32093209
def __init__(self, name: str) -> None:
3210-
self.name = name # E: Cannot assign to a method \
3211-
# E: Incompatible types in assignment (expression has type "str", variable has type "Callable[..., Any]")
3210+
self.name = name # E: Cannot assign to a method
32123211

32133212
[case testNewAnalyzerMemberNameMatchesTypedDict]
32143213
from typing import Union, Any

test-data/unit/check-overloading.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6613,3 +6613,30 @@ def struct(__cols: Union[List[S], Tuple[S, ...]]) -> int: ...
66136613
def struct(*cols: Union[S, Union[List[S], Tuple[S, ...]]]) -> int:
66146614
pass
66156615
[builtins fixtures/tuple.pyi]
6616+
6617+
[case testRegularGenericDecoratorOverload]
6618+
from typing import Callable, overload, TypeVar, List
6619+
6620+
S = TypeVar("S")
6621+
T = TypeVar("T")
6622+
def transform(func: Callable[[S], List[T]]) -> Callable[[S], T]: ...
6623+
6624+
@overload
6625+
def foo(x: int) -> List[float]: ...
6626+
@overload
6627+
def foo(x: str) -> List[str]: ...
6628+
def foo(x): ...
6629+
6630+
reveal_type(transform(foo)) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"
6631+
6632+
@transform
6633+
@overload
6634+
def bar(x: int) -> List[float]: ...
6635+
@transform
6636+
@overload
6637+
def bar(x: str) -> List[str]: ...
6638+
@transform
6639+
def bar(x): ...
6640+
6641+
reveal_type(bar) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)"
6642+
[builtins fixtures/paramspec.pyi]

test-data/unit/check-parameter-specification.test

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,3 +1646,31 @@ def bar(b: B[P]) -> A[Concatenate[int, P]]:
16461646
# N: Got: \
16471647
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
16481648
[builtins fixtures/paramspec.pyi]
1649+
1650+
[case testParamSpecDecoratorOverload]
1651+
from typing import Callable, overload, TypeVar, List
1652+
from typing_extensions import ParamSpec
1653+
1654+
P = ParamSpec("P")
1655+
T = TypeVar("T")
1656+
def transform(func: Callable[P, List[T]]) -> Callable[P, T]: ...
1657+
1658+
@overload
1659+
def foo(x: int) -> List[float]: ...
1660+
@overload
1661+
def foo(x: str) -> List[str]: ...
1662+
def foo(x): ...
1663+
1664+
reveal_type(transform(foo)) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"
1665+
1666+
@transform
1667+
@overload
1668+
def bar(x: int) -> List[float]: ...
1669+
@transform
1670+
@overload
1671+
def bar(x: str) -> List[str]: ...
1672+
@transform
1673+
def bar(x): ...
1674+
1675+
reveal_type(bar) # N: Revealed type is "Overload(def (x: builtins.int) -> builtins.float, def (x: builtins.str) -> builtins.str)"
1676+
[builtins fixtures/paramspec.pyi]

test-data/unit/lib-stub/functools.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, TypeVar, Callable, Any, Mapping
1+
from typing import Generic, TypeVar, Callable, Any, Mapping, overload
22

33
_T = TypeVar("_T")
44

0 commit comments

Comments
 (0)