Skip to content

Represent finite types as union of literals for checking overloads #17437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3020,6 +3020,10 @@ def union_overload_result(
# Step 4: Split the first remaining union type in arguments into items and
# try to match each item individually (recursive).
first_union = get_proper_type(arg_types[idx])
if not isinstance(first_union, UnionType):
assert isinstance(first_union, Instance)
# enum or bool
first_union = self.split_into_literals(first_union)
assert isinstance(first_union, UnionType)
res_items = []
for item in first_union.relevant_items():
Expand Down Expand Up @@ -3054,7 +3058,30 @@ def union_overload_result(

def real_union(self, typ: Type) -> bool:
typ = get_proper_type(typ)
return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1
return (
isinstance(typ, UnionType)
and len(typ.relevant_items()) > 1
or isinstance(typ, Instance)
and (typ.type.is_enum or typ.type.fullname == "builtins.bool")
)

def split_into_literals(self, typ: Instance) -> UnionType:
"""Represent a finite type as a union of literals.

When trying overloads, enum and bool can fail to match to union of overloads.
Convert them to union representation to try harder.
"""

if typ.type.fullname == "builtins.bool":
return UnionType(
[
LiteralType(True, self.chk.named_type("builtins.bool")),
LiteralType(False, self.chk.named_type("builtins.bool")),
]
)
if typ.type.is_enum:
return UnionType([LiteralType(name, typ) for name in typ.get_enum_values()])
raise NotImplementedError("Only bool and enum types can be split into union.")

@contextmanager
def type_overrides_set(
Expand Down
8 changes: 4 additions & 4 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -3050,10 +3050,10 @@ main:15: error: Unsupported left operand type for >= ("NoCmp")
[case testAttrsIncrementalDunder]
from a import A
reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> a.A"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool"

A(1) < A(2)
A(1) <= A(2)
Expand Down
39 changes: 39 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6750,3 +6750,42 @@ def foo(x: object) -> str: ...
def bar(x: int) -> int: ...
@overload
def bar(x: Any) -> str: ...

[case testOverloadFiniteTypeBoolMatches]
from typing import Union, overload
from typing_extensions import Literal

@overload
def foo(a: Literal[True]) -> int: ...

@overload
def foo(a: Literal[False]) -> str: ...

def foo(a: bool) -> Union[int, str]: ...

a: bool
reveal_type(foo(a)) # N: Revealed type is "Union[builtins.int, builtins.str]"

[builtins fixtures/tuple.pyi]

[case testOverloadFiniteTypeEnumMatches]
from enum import Enum
from typing import Union, overload
from typing_extensions import Literal

class A(Enum):
a = 1
b = 2

@overload
def foo(a: Literal[A.a]) -> int: ...

@overload
def foo(a: Literal[A.b]) -> str: ...

def foo(a: A) -> Union[int, str]: ...

a: A
reveal_type(foo(a)) # N: Revealed type is "Union[builtins.int, builtins.str]"

[builtins fixtures/tuple.pyi]
8 changes: 4 additions & 4 deletions test-data/unit/check-plugin-attrs.test
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ from attr import attrib, attrs
class A:
a: int
reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> __main__.A"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool"
reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool"
reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool"
reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool"

A(1) < A(2)
A(1) <= A(2)
Expand Down
Loading