diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 280e9a35d537..0a6c955311c5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3841,8 +3841,10 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F allow_none_return=allow_none_return) # Analyze the right branch using full type context and store the type - full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr, context=ctx, - allow_none_return=allow_none_return) + full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr, + context=ctx, + allow_none_return=allow_none_return, + is_else=True) if not mypy.checker.is_valid_inferred_type(if_type): # Analyze the right branch disregarding the left branch. else_type = full_context_else_type @@ -3860,7 +3862,8 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F # Analyze the right branch in the context of the left # branch's type. else_type = self.analyze_cond_branch(else_map, e.else_expr, context=if_type, - allow_none_return=allow_none_return) + allow_none_return=allow_none_return, + is_else=True) # Only create a union type if the type context is a union, to be mostly # compatible with older mypy versions where we always did a join. @@ -3875,20 +3878,32 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F def analyze_cond_branch(self, map: Optional[Dict[Expression, Type]], node: Expression, context: Optional[Type], - allow_none_return: bool = False) -> Type: + allow_none_return: bool = False, + is_else: bool = False) -> Type: # We need to be have the correct amount of binder frames. - # Sometimes it can be missing for unreachable parts. + # Sometimes it can be missing for unreachable left or right parts. with ( self.chk.binder.frame_context(can_skip=True, fall_through=0) if len(self.chk.binder.frames) > 1 else self.chk.binder.top_frame_context() ): + if map is not None: + self.chk.push_type_map(map) + if is_else and context is not None and isinstance(node, CallExpr): + # When calling a function on the else part, + # we can face a generic function with multiple type vars. + # When inferecing it, `context` might be used instead of real args. + # Usually, we don't want that. + # https://github.com/python/mypy/issues/11049 + with self.msg.disable_errors(): + call_type = self.accept(node) + if not is_subtype(call_type, context, ignore_type_params=True): + context = None if map is None: # We still need to type check node, in case we want to # process it for isinstance checks later self.accept(node, type_context=context, allow_none_return=allow_none_return) return UninhabitedType() - self.chk.push_type_map(map) return self.accept(node, type_context=context, allow_none_return=allow_none_return) def visit_backquote_expr(self, e: BackquoteExpr) -> Type: diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 80bc40b6ca98..4170bd630b09 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -1975,7 +1975,7 @@ T = TypeVar('T') class A: def f(self) -> None: - self.g() # E: Too few arguments for "g" of "A" + self.g() # E: Too few arguments for "g" of "A" self.g(1) @dec def g(self, x: str) -> None: pass @@ -2246,6 +2246,30 @@ a = set() if f() else {0} a() # E: "Set[int]" not callable [builtins fixtures/set.pyi] +[case testUnificationEmptySetRight] +def f(): pass +a = {0} if f() else set() +a() # E: "Set[int]" not callable +[builtins fixtures/set.pyi] + +[case testUnificationEmptyCustomSetLeft] +from typing import Set, TypeVar +T = TypeVar('T') +class customset(Set[T]): pass +def f(): pass +a = customset() if f() else {1} +a() # E: "Set[int]" not callable +[builtins fixtures/set.pyi] + +[case testUnificationEmptyCustomSetRight] +from typing import Set, TypeVar +T = TypeVar('T') +class customset(Set[T]): pass +def f(): pass +a = {0} if f() else customset() +a() # E: "Set[int]" not callable +[builtins fixtures/set.pyi] + [case testUnificationEmptyDictLeft] def f(): pass a = {} if f() else {0: 0} @@ -2270,6 +2294,58 @@ a = {0: [0]} if f() else {0: []} a() # E: "Dict[int, List[int]]" not callable [builtins fixtures/dict.pyi] +[case testConditionalInferenceGenericFunctionRight] +from typing import TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +def foo(a: T1, b: T2) -> Union[T1, T2]: pass +x: bool + +reveal_type(1 if x else foo(1, "s")) # N: Revealed type is "Union[builtins.int*, builtins.str*]" +reveal_type("a" if x else foo(1, "s")) # N: Revealed type is "Union[builtins.int*, builtins.str*]" +reveal_type(1 if x else foo("s", 1)) # N: Revealed type is "Union[builtins.str*, builtins.int*]" +reveal_type("a" if x else foo("s", 1)) # N: Revealed type is "Union[builtins.str*, builtins.int*]" +[builtins fixtures/bool.pyi] + +[case testConditionalInferenceGenericFunctionLeft] +from typing import TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +def foo(a: T1, b: T2) -> Union[T1, T2]: pass +x: bool + +reveal_type(foo(1, "s") if x else 1) # N: Revealed type is "Union[builtins.int*, builtins.str*]" +reveal_type(foo(1, "s") if x else "a") # N: Revealed type is "Union[builtins.int*, builtins.str*]" +reveal_type(foo("s", 1) if x else 1) # N: Revealed type is "Union[builtins.str*, builtins.int*]" +reveal_type(foo("s", 1) if x else "a") # N: Revealed type is "Union[builtins.str*, builtins.int*]" +[builtins fixtures/bool.pyi] + +[case testConditionalInferenceSelfNarrowingRight] +from typing import Optional + +class C: + x: Optional[int] + def check(self) -> Optional[int]: + return None if self.x is None else self.x.conjugate() + +reveal_type(C().check()) # N: Revealed type is "Union[builtins.int, None]" +[builtins fixtures/bool.pyi] + +[case testConditionalInferenceSelfNarrowingLeft] +from typing import Optional + +class C: + x: Optional[int] + def check(self) -> Optional[int]: + return self.x.conjugate() if self.x is not None else None + +reveal_type(C().check()) # N: Revealed type is "Union[builtins.int, None]" +[builtins fixtures/bool.pyi] + [case testMisguidedSetItem] from typing import Generic, Sequence, TypeVar T = TypeVar('T') diff --git a/test-data/unit/fixtures/bool.pyi b/test-data/unit/fixtures/bool.pyi index ca2564dabafd..a82a45f06e6a 100644 --- a/test-data/unit/fixtures/bool.pyi +++ b/test-data/unit/fixtures/bool.pyi @@ -10,7 +10,8 @@ class object: class type: pass class tuple(Generic[T]): pass class function: pass -class int: pass +class int: + def conjugate(self) -> int: pass class bool(int): pass class float: pass class str: pass