@@ -1188,14 +1188,14 @@ def check_overload_call(self,
11881188 # gives a narrower type.
11891189 if unioned_return :
11901190 returns , inferred_types = zip (* unioned_return )
1191- # Note that we use `union_overload_matches ` instead of just returning
1191+ # Note that we use `combine_function_signatures ` instead of just returning
11921192 # a union of inferred callables because for example a call
11931193 # Union[int -> int, str -> str](Union[int, str]) is invalid and
11941194 # we don't want to introduce internal inconsistencies.
11951195 unioned_result = (UnionType .make_simplified_union (list (returns ),
11961196 context .line ,
11971197 context .column ),
1198- self .union_overload_matches (inferred_types ))
1198+ self .combine_function_signatures (inferred_types ))
11991199
12001200 # Step 3: We try checking each branch one-by-one.
12011201 inferred_result = self .infer_overload_return_type (plausible_targets , args , arg_types ,
@@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
14921492 for expr in exprs :
14931493 del self .type_overrides [expr ]
14941494
1495- def union_overload_matches (self , types : Sequence [Type ]) -> Union [AnyType , CallableType ]:
1496- """Accepts a list of overload signatures and attempts to combine them together into a
1495+ def combine_function_signatures (self , types : Sequence [Type ]) -> Union [AnyType , CallableType ]:
1496+ """Accepts a list of function signatures and attempts to combine them together into a
14971497 new CallableType consisting of the union of all of the given arguments and return types.
14981498
14991499 If there is at least one non-callable type, return Any (this can happen if there is
@@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
15071507 return callables [0 ]
15081508
15091509 # Note: we are assuming here that if a user uses some TypeVar 'T' in
1510- # two different overloads , they meant for that TypeVar to mean the
1510+ # two different functions , they meant for that TypeVar to mean the
15111511 # same thing.
15121512 #
15131513 # This function will make sure that all instances of that TypeVar 'T'
@@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
15251525
15261526 too_complex = False
15271527 for target in callables :
1528- # We fall back to Callable[..., Union[<returns>]] if the overloads do not have
1528+ # We fall back to Callable[..., Union[<returns>]] if the functions do not have
15291529 # the exact same signature. The only exception is if one arg is optional and
15301530 # the other is positional: in that case, we continue unioning (and expect a
15311531 # positional arg).
@@ -1820,19 +1820,12 @@ def check_op_reversible(self,
18201820 left_expr : Expression ,
18211821 right_type : Type ,
18221822 right_expr : Expression ,
1823- context : Context ) -> Tuple [Type , Type ]:
1824- # Note: this kludge exists mostly to maintain compatibility with
1825- # existing error messages. Apparently, if the left-hand-side is a
1826- # union and we have a type mismatch, we print out a special,
1827- # abbreviated error message. (See messages.unsupported_operand_types).
1828- unions_present = isinstance (left_type , UnionType )
1829-
1823+ context : Context ,
1824+ msg : MessageBuilder ) -> Tuple [Type , Type ]:
18301825 def make_local_errors () -> MessageBuilder :
18311826 """Creates a new MessageBuilder object."""
1832- local_errors = self . msg .clean_copy ()
1827+ local_errors = msg .clean_copy ()
18331828 local_errors .disable_count = 0
1834- if unions_present :
1835- local_errors .disable_type_names += 1
18361829 return local_errors
18371830
18381831 def lookup_operator (op_name : str , base_type : Type ) -> Optional [Type ]:
@@ -2006,30 +1999,101 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
20061999 # TODO: Remove this extra case
20072000 return result
20082001
2009- self . msg .add_errors (errors [0 ])
2002+ msg .add_errors (errors [0 ])
20102003 if len (results ) == 1 :
20112004 return results [0 ]
20122005 else :
20132006 error_any = AnyType (TypeOfAny .from_error )
20142007 result = error_any , error_any
20152008 return result
20162009
2017- def check_op (self , method : str , base_type : Type , arg : Expression ,
2018- context : Context ,
2010+ def check_op (self , method : str , base_type : Type ,
2011+ arg : Expression , context : Context ,
20192012 allow_reverse : bool = False ) -> Tuple [Type , Type ]:
20202013 """Type check a binary operation which maps to a method call.
20212014
20222015 Return tuple (result type, inferred operator method type).
20232016 """
20242017
20252018 if allow_reverse :
2026- return self .check_op_reversible (
2027- op_name = method ,
2028- left_type = base_type ,
2029- left_expr = TempNode (base_type ),
2030- right_type = self .accept (arg ),
2031- right_expr = arg ,
2032- context = context )
2019+ left_variants = [base_type ]
2020+ if isinstance (base_type , UnionType ):
2021+ left_variants = [item for item in base_type .relevant_items ()]
2022+ right_type = self .accept (arg )
2023+
2024+ # Step 1: We first try leaving the right arguments alone and destructure
2025+ # just the left ones. (Mypy can sometimes perform some more precise inference
2026+ # if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
2027+ msg = self .msg .clean_copy ()
2028+ msg .disable_count = 0
2029+ all_results = []
2030+ all_inferred = []
2031+
2032+ for left_possible_type in left_variants :
2033+ result , inferred = self .check_op_reversible (
2034+ op_name = method ,
2035+ left_type = left_possible_type ,
2036+ left_expr = TempNode (left_possible_type ),
2037+ right_type = right_type ,
2038+ right_expr = arg ,
2039+ context = context ,
2040+ msg = msg )
2041+ all_results .append (result )
2042+ all_inferred .append (inferred )
2043+
2044+ if not msg .is_errors ():
2045+ results_final = UnionType .make_simplified_union (all_results )
2046+ inferred_final = UnionType .make_simplified_union (all_inferred )
2047+ return results_final , inferred_final
2048+
2049+ # Step 2: If that fails, we try again but also destructure the right argument.
2050+ # This is also necessary to make certain edge cases work -- see
2051+ # testOperatorDoubleUnionInterwovenUnionAdd, for example.
2052+
2053+ # Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
2054+ # whenever possible so that plugins and similar things can introspect on the original
2055+ # node if possible.
2056+ #
2057+ # We don't do the same for the base expression because it could lead to weird
2058+ # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
2059+ # TODO: Can we use `type_overrides_set()` here?
2060+ right_variants = [(right_type , arg )]
2061+ if isinstance (right_type , UnionType ):
2062+ right_variants = [(item , TempNode (item )) for item in right_type .relevant_items ()]
2063+
2064+ msg = self .msg .clean_copy ()
2065+ msg .disable_count = 0
2066+ all_results = []
2067+ all_inferred = []
2068+
2069+ for left_possible_type in left_variants :
2070+ for right_possible_type , right_expr in right_variants :
2071+ result , inferred = self .check_op_reversible (
2072+ op_name = method ,
2073+ left_type = left_possible_type ,
2074+ left_expr = TempNode (left_possible_type ),
2075+ right_type = right_possible_type ,
2076+ right_expr = right_expr ,
2077+ context = context ,
2078+ msg = msg )
2079+ all_results .append (result )
2080+ all_inferred .append (inferred )
2081+
2082+ if msg .is_errors ():
2083+ self .msg .add_errors (msg )
2084+ if len (left_variants ) >= 2 and len (right_variants ) >= 2 :
2085+ self .msg .warn_both_operands_are_from_unions (context )
2086+ elif len (left_variants ) >= 2 :
2087+ self .msg .warn_operand_was_from_union ("Left" , base_type , context )
2088+ elif len (right_variants ) >= 2 :
2089+ self .msg .warn_operand_was_from_union ("Right" , right_type , context )
2090+
2091+ # See the comment in 'check_overload_call' for more details on why
2092+ # we call 'combine_function_signature' instead of just unioning the inferred
2093+ # callable types.
2094+ results_final = UnionType .make_simplified_union (all_results )
2095+ inferred_final = self .combine_function_signatures (all_inferred )
2096+ return results_final , inferred_final
20332097 else :
20342098 return self .check_op_local_by_name (
20352099 method = method ,
0 commit comments