Skip to content

Commit 0777c10

Browse files
authored
Add support for conditionally defined overloads (#10712)
### Description This PR allows users to define overloads conditionally, e.g., based on the Python version. At the moment this is only possible if all overloads are contained in the same block which requires duplications. ```py from typing import overload, Any import sys class A: ... class B: ... if sys.version_info >= (3, 9): class C: ... @overload def func(g: int) -> A: ... @overload def func(g: bytes) -> B: ... if sys.version_info >= (3, 9): @overload def func(g: str) -> C: ... def func(g: Any) -> Any: ... ``` Closes #9744 ## Test Plan Unit tests have been added. ## Limitations Only `if` is supported. Support for `elif` and `else` might be added in the future. However, I believe that the single if as shown in the example is also the most common use case. The change itself is fully backwards compatible, i.e. the current workaround (see below) will continue to function as expected. ~~**Update**: Seems like support for `elif` and `else` is required for the tests to pass.~~ **Update**: Added support for `elif` and `else`. ## Current workaround ```py from typing import overload, Any import sys class A: ... class B: ... if sys.version_info >= (3, 9): class C: ... if sys.version_info >= (3, 9): @overload def func(g: int) -> A: ... @overload def func(g: bytes) -> B: ... @overload def func(g: str) -> C: ... def func(g: Any) -> Any: ... else: @overload def func(g: int) -> A: ... @overload def func(g: bytes) -> B: ... def func(g: Any) -> Any: ... ```
1 parent 68b3b27 commit 0777c10

File tree

3 files changed

+1185
-3
lines changed

3 files changed

+1185
-3
lines changed

docs/source/more_types.rst

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,114 @@ with ``Union[int, slice]`` and ``Union[T, Sequence]``.
581581
to returning ``Any`` only if the input arguments also contain ``Any``.
582582

583583

584+
Conditional overloads
585+
---------------------
586+
587+
Sometimes it is useful to define overloads conditionally.
588+
Common use cases include types that are unavailable at runtime or that
589+
only exist in a certain Python version. All existing overload rules still apply.
590+
For example, there must be at least two overloads.
591+
592+
.. note::
593+
594+
Mypy can only infer a limited number of conditions.
595+
Supported ones currently include :py:data:`~typing.TYPE_CHECKING`, ``MYPY``,
596+
:ref:`version_and_platform_checks`, and :option:`--always-true <mypy --always-true>`
597+
and :option:`--always-false <mypy --always-false>` values.
598+
599+
.. code-block:: python
600+
601+
from typing import TYPE_CHECKING, Any, overload
602+
603+
if TYPE_CHECKING:
604+
class A: ...
605+
class B: ...
606+
607+
608+
if TYPE_CHECKING:
609+
@overload
610+
def func(var: A) -> A: ...
611+
612+
@overload
613+
def func(var: B) -> B: ...
614+
615+
def func(var: Any) -> Any:
616+
return var
617+
618+
619+
reveal_type(func(A())) # Revealed type is "A"
620+
621+
.. code-block:: python
622+
623+
# flags: --python-version 3.10
624+
import sys
625+
from typing import Any, overload
626+
627+
class A: ...
628+
class B: ...
629+
class C: ...
630+
class D: ...
631+
632+
633+
if sys.version_info < (3, 7):
634+
@overload
635+
def func(var: A) -> A: ...
636+
637+
elif sys.version_info >= (3, 10):
638+
@overload
639+
def func(var: B) -> B: ...
640+
641+
else:
642+
@overload
643+
def func(var: C) -> C: ...
644+
645+
@overload
646+
def func(var: D) -> D: ...
647+
648+
def func(var: Any) -> Any:
649+
return var
650+
651+
652+
reveal_type(func(B())) # Revealed type is "B"
653+
reveal_type(func(C())) # No overload variant of "func" matches argument type "C"
654+
# Possible overload variants:
655+
# def func(var: B) -> B
656+
# def func(var: D) -> D
657+
# Revealed type is "Any"
658+
659+
660+
.. note::
661+
662+
In the last example, mypy is executed with
663+
:option:`--python-version 3.10 <mypy --python-version>`.
664+
Therefore, the condition ``sys.version_info >= (3, 10)`` will match and
665+
the overload for ``B`` will be added.
666+
The overloads for ``A`` and ``C`` are ignored!
667+
The overload for ``D`` is not defined conditionally and thus is also added.
668+
669+
When mypy cannot infer a condition to be always True or always False, an error is emitted.
670+
671+
.. code-block:: python
672+
673+
from typing import Any, overload
674+
675+
class A: ...
676+
class B: ...
677+
678+
679+
def g(bool_var: bool) -> None:
680+
if bool_var: # Condition can't be inferred, unable to merge overloads
681+
@overload
682+
def func(var: A) -> A: ...
683+
684+
@overload
685+
def func(var: B) -> B: ...
686+
687+
def func(var: Any) -> Any: ...
688+
689+
reveal_type(func(A())) # Revealed type is "Any"
690+
691+
584692
.. _advanced_self:
585693

586694
Advanced uses of self-types

mypy/fastparse.py

Lines changed: 196 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from mypy import message_registry, errorcodes as codes
4545
from mypy.errors import Errors
4646
from mypy.options import Options
47-
from mypy.reachability import mark_block_unreachable
47+
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
4848
from mypy.util import bytes_to_human_readable_repr
4949

5050
try:
@@ -344,9 +344,19 @@ def fail(self,
344344
msg: str,
345345
line: int,
346346
column: int,
347-
blocker: bool = True) -> None:
347+
blocker: bool = True,
348+
code: codes.ErrorCode = codes.SYNTAX) -> None:
348349
if blocker or not self.options.ignore_errors:
349-
self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX)
350+
self.errors.report(line, column, msg, blocker=blocker, code=code)
351+
352+
def fail_merge_overload(self, node: IfStmt) -> None:
353+
self.fail(
354+
"Condition can't be inferred, unable to merge overloads",
355+
line=node.line,
356+
column=node.column,
357+
blocker=False,
358+
code=codes.MISC,
359+
)
350360

351361
def visit(self, node: Optional[AST]) -> Any:
352362
if node is None:
@@ -476,12 +486,93 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
476486
ret: List[Statement] = []
477487
current_overload: List[OverloadPart] = []
478488
current_overload_name: Optional[str] = None
489+
last_if_stmt: Optional[IfStmt] = None
490+
last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None
491+
last_if_stmt_overload_name: Optional[str] = None
492+
last_if_unknown_truth_value: Optional[IfStmt] = None
493+
skipped_if_stmts: List[IfStmt] = []
479494
for stmt in stmts:
495+
if_overload_name: Optional[str] = None
496+
if_block_with_overload: Optional[Block] = None
497+
if_unknown_truth_value: Optional[IfStmt] = None
498+
if (
499+
isinstance(stmt, IfStmt)
500+
and len(stmt.body[0].body) == 1
501+
and (
502+
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
503+
or current_overload_name is not None
504+
and isinstance(stmt.body[0].body[0], FuncDef)
505+
)
506+
):
507+
# Check IfStmt block to determine if function overloads can be merged
508+
if_overload_name = self._check_ifstmt_for_overloads(stmt)
509+
if if_overload_name is not None:
510+
if_block_with_overload, if_unknown_truth_value = \
511+
self._get_executable_if_block_with_overloads(stmt)
512+
480513
if (current_overload_name is not None
481514
and isinstance(stmt, (Decorator, FuncDef))
482515
and stmt.name == current_overload_name):
516+
if last_if_stmt is not None:
517+
skipped_if_stmts.append(last_if_stmt)
518+
if last_if_overload is not None:
519+
# Last stmt was an IfStmt with same overload name
520+
# Add overloads to current_overload
521+
if isinstance(last_if_overload, OverloadedFuncDef):
522+
current_overload.extend(last_if_overload.items)
523+
else:
524+
current_overload.append(last_if_overload)
525+
last_if_stmt, last_if_overload = None, None
526+
if last_if_unknown_truth_value:
527+
self.fail_merge_overload(last_if_unknown_truth_value)
528+
last_if_unknown_truth_value = None
483529
current_overload.append(stmt)
530+
elif (
531+
current_overload_name is not None
532+
and isinstance(stmt, IfStmt)
533+
and if_overload_name == current_overload_name
534+
):
535+
# IfStmt only contains stmts relevant to current_overload.
536+
# Check if stmts are reachable and add them to current_overload,
537+
# otherwise skip IfStmt to allow subsequent overload
538+
# or function definitions.
539+
skipped_if_stmts.append(stmt)
540+
if if_block_with_overload is None:
541+
if if_unknown_truth_value is not None:
542+
self.fail_merge_overload(if_unknown_truth_value)
543+
continue
544+
if last_if_overload is not None:
545+
# Last stmt was an IfStmt with same overload name
546+
# Add overloads to current_overload
547+
if isinstance(last_if_overload, OverloadedFuncDef):
548+
current_overload.extend(last_if_overload.items)
549+
else:
550+
current_overload.append(last_if_overload)
551+
last_if_stmt, last_if_overload = None, None
552+
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
553+
current_overload.extend(if_block_with_overload.body[0].items)
554+
else:
555+
current_overload.append(
556+
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
557+
)
484558
else:
559+
if last_if_stmt is not None:
560+
ret.append(last_if_stmt)
561+
last_if_stmt_overload_name = current_overload_name
562+
last_if_stmt, last_if_overload = None, None
563+
last_if_unknown_truth_value = None
564+
565+
if current_overload and current_overload_name == last_if_stmt_overload_name:
566+
# Remove last stmt (IfStmt) from ret if the overload names matched
567+
# Only happens if no executable block had been found in IfStmt
568+
skipped_if_stmts.append(cast(IfStmt, ret.pop()))
569+
if current_overload and skipped_if_stmts:
570+
# Add bare IfStmt (without overloads) to ret
571+
# Required for mypy to be able to still check conditions
572+
for if_stmt in skipped_if_stmts:
573+
self._strip_contents_from_if_stmt(if_stmt)
574+
ret.append(if_stmt)
575+
skipped_if_stmts = []
485576
if len(current_overload) == 1:
486577
ret.append(current_overload[0])
487578
elif len(current_overload) > 1:
@@ -495,17 +586,119 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
495586
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
496587
current_overload = [stmt]
497588
current_overload_name = stmt.name
589+
elif (
590+
isinstance(stmt, IfStmt)
591+
and if_overload_name is not None
592+
):
593+
current_overload = []
594+
current_overload_name = if_overload_name
595+
last_if_stmt = stmt
596+
last_if_stmt_overload_name = None
597+
if if_block_with_overload is not None:
598+
last_if_overload = cast(
599+
Union[Decorator, FuncDef, OverloadedFuncDef],
600+
if_block_with_overload.body[0]
601+
)
602+
last_if_unknown_truth_value = if_unknown_truth_value
498603
else:
499604
current_overload = []
500605
current_overload_name = None
501606
ret.append(stmt)
502607

608+
if current_overload and skipped_if_stmts:
609+
# Add bare IfStmt (without overloads) to ret
610+
# Required for mypy to be able to still check conditions
611+
for if_stmt in skipped_if_stmts:
612+
self._strip_contents_from_if_stmt(if_stmt)
613+
ret.append(if_stmt)
503614
if len(current_overload) == 1:
504615
ret.append(current_overload[0])
505616
elif len(current_overload) > 1:
506617
ret.append(OverloadedFuncDef(current_overload))
618+
elif last_if_stmt is not None:
619+
ret.append(last_if_stmt)
507620
return ret
508621

622+
def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
623+
"""Check if IfStmt contains only overloads with the same name.
624+
Return overload_name if found, None otherwise.
625+
"""
626+
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
627+
# Multiple overloads have already been merged as OverloadedFuncDef.
628+
if not (
629+
len(stmt.body[0].body) == 1
630+
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
631+
):
632+
return None
633+
634+
overload_name = stmt.body[0].body[0].name
635+
if stmt.else_body is None:
636+
return overload_name
637+
638+
if isinstance(stmt.else_body, Block) and len(stmt.else_body.body) == 1:
639+
# For elif: else_body contains an IfStmt itself -> do a recursive check.
640+
if (
641+
isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef))
642+
and stmt.else_body.body[0].name == overload_name
643+
):
644+
return overload_name
645+
if (
646+
isinstance(stmt.else_body.body[0], IfStmt)
647+
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
648+
):
649+
return overload_name
650+
651+
return None
652+
653+
def _get_executable_if_block_with_overloads(
654+
self, stmt: IfStmt
655+
) -> Tuple[Optional[Block], Optional[IfStmt]]:
656+
"""Return block from IfStmt that will get executed.
657+
658+
Return
659+
0 -> A block if sure that alternative blocks are unreachable.
660+
1 -> An IfStmt if the reachability of it can't be inferred,
661+
i.e. the truth value is unknown.
662+
"""
663+
infer_reachability_of_if_statement(stmt, self.options)
664+
if (
665+
stmt.else_body is None
666+
and stmt.body[0].is_unreachable is True
667+
):
668+
# always False condition with no else
669+
return None, None
670+
if (
671+
stmt.else_body is None
672+
or stmt.body[0].is_unreachable is False
673+
and stmt.else_body.is_unreachable is False
674+
):
675+
# The truth value is unknown, thus not conclusive
676+
return None, stmt
677+
if stmt.else_body.is_unreachable is True:
678+
# else_body will be set unreachable if condition is always True
679+
return stmt.body[0], None
680+
if stmt.body[0].is_unreachable is True:
681+
# body will be set unreachable if condition is always False
682+
# else_body can contain an IfStmt itself (for elif) -> do a recursive check
683+
if isinstance(stmt.else_body.body[0], IfStmt):
684+
return self._get_executable_if_block_with_overloads(stmt.else_body.body[0])
685+
return stmt.else_body, None
686+
return None, stmt
687+
688+
def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
689+
"""Remove contents from IfStmt.
690+
691+
Needed to still be able to check the conditions after the contents
692+
have been merged with the surrounding function overloads.
693+
"""
694+
if len(stmt.body) == 1:
695+
stmt.body[0].body = []
696+
if stmt.else_body and len(stmt.else_body.body) == 1:
697+
if isinstance(stmt.else_body.body[0], IfStmt):
698+
self._strip_contents_from_if_stmt(stmt.else_body.body[0])
699+
else:
700+
stmt.else_body.body = []
701+
509702
def in_method_scope(self) -> bool:
510703
return self.class_and_function_stack[-2:] == ['C', 'F']
511704

0 commit comments

Comments
 (0)