diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index fdc0f94b3997..ec8b12920ca3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3020,6 +3020,10 @@ def union_overload_result( # Step 4: Split the first remaining union type in arguments into items and # try to match each item individually (recursive). first_union = get_proper_type(arg_types[idx]) + if not isinstance(first_union, UnionType): + assert isinstance(first_union, Instance) + # enum or bool + first_union = self.split_into_literals(first_union) assert isinstance(first_union, UnionType) res_items = [] for item in first_union.relevant_items(): @@ -3054,7 +3058,30 @@ def union_overload_result( def real_union(self, typ: Type) -> bool: typ = get_proper_type(typ) - return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 + return ( + isinstance(typ, UnionType) + and len(typ.relevant_items()) > 1 + or isinstance(typ, Instance) + and (typ.type.is_enum or typ.type.fullname == "builtins.bool") + ) + + def split_into_literals(self, typ: Instance) -> UnionType: + """Represent a finite type as a union of literals. + + When trying overloads, enum and bool can fail to match to union of overloads. + Convert them to union representation to try harder. + """ + + if typ.type.fullname == "builtins.bool": + return UnionType( + [ + LiteralType(True, self.chk.named_type("builtins.bool")), + LiteralType(False, self.chk.named_type("builtins.bool")), + ] + ) + if typ.type.is_enum: + return UnionType([LiteralType(name, typ) for name in typ.get_enum_values()]) + raise NotImplementedError("Only bool and enum types can be split into union.") @contextmanager def type_overrides_set( diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 24292bce3e21..80092015bd4e 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -3050,10 +3050,10 @@ main:15: error: Unsupported left operand type for >= ("NoCmp") [case testAttrsIncrementalDunder] from a import A reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> a.A" -reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool" -reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" -reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" -reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" +reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" +reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool" A(1) < A(2) A(1) <= A(2) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 48d5996b226f..512d2aba4bb1 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6750,3 +6750,42 @@ def foo(x: object) -> str: ... def bar(x: int) -> int: ... @overload def bar(x: Any) -> str: ... + +[case testOverloadFiniteTypeBoolMatches] +from typing import Union, overload +from typing_extensions import Literal + +@overload +def foo(a: Literal[True]) -> int: ... + +@overload +def foo(a: Literal[False]) -> str: ... + +def foo(a: bool) -> Union[int, str]: ... + +a: bool +reveal_type(foo(a)) # N: Revealed type is "Union[builtins.int, builtins.str]" + +[builtins fixtures/tuple.pyi] + +[case testOverloadFiniteTypeEnumMatches] +from enum import Enum +from typing import Union, overload +from typing_extensions import Literal + +class A(Enum): + a = 1 + b = 2 + +@overload +def foo(a: Literal[A.a]) -> int: ... + +@overload +def foo(a: Literal[A.b]) -> str: ... + +def foo(a: A) -> Union[int, str]: ... + +a: A +reveal_type(foo(a)) # N: Revealed type is "Union[builtins.int, builtins.str]" + +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-plugin-attrs.test b/test-data/unit/check-plugin-attrs.test index b96c00730a74..fc664cbd1264 100644 --- a/test-data/unit/check-plugin-attrs.test +++ b/test-data/unit/check-plugin-attrs.test @@ -185,10 +185,10 @@ from attr import attrib, attrs class A: a: int reveal_type(A) # N: Revealed type is "def (a: builtins.int) -> __main__.A" -reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`3, other: _AT`3) -> builtins.bool" -reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" -reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" -reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__lt__) # N: Revealed type is "def [_AT] (self: _AT`4, other: _AT`4) -> builtins.bool" +reveal_type(A.__le__) # N: Revealed type is "def [_AT] (self: _AT`5, other: _AT`5) -> builtins.bool" +reveal_type(A.__gt__) # N: Revealed type is "def [_AT] (self: _AT`6, other: _AT`6) -> builtins.bool" +reveal_type(A.__ge__) # N: Revealed type is "def [_AT] (self: _AT`7, other: _AT`7) -> builtins.bool" A(1) < A(2) A(1) <= A(2)