Skip to content
Merged
116 changes: 90 additions & 26 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,14 @@ def check_overload_call(self,
# gives a narrower type.
if unioned_return:
returns, inferred_types = zip(*unioned_return)
# Note that we use `union_overload_matches` instead of just returning
# Note that we use `combine_function_signatures` instead of just returning
# a union of inferred callables because for example a call
# Union[int -> int, str -> str](Union[int, str]) is invalid and
# we don't want to introduce internal inconsistencies.
unioned_result = (UnionType.make_simplified_union(list(returns),
context.line,
context.column),
self.union_overload_matches(inferred_types))
self.combine_function_signatures(inferred_types))

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
Expand Down Expand Up @@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
for expr in exprs:
del self.type_overrides[expr]

def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of function signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.

If there is at least one non-callable type, return Any (this can happen if there is
Expand All @@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
return callables[0]

# Note: we are assuming here that if a user uses some TypeVar 'T' in
# two different overloads, they meant for that TypeVar to mean the
# two different functions, they meant for that TypeVar to mean the
# same thing.
#
# This function will make sure that all instances of that TypeVar 'T'
Expand All @@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab

too_complex = False
for target in callables:
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# We fall back to Callable[..., Union[<returns>]] if the functions do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
Expand Down Expand Up @@ -1820,19 +1820,12 @@ def check_op_reversible(self,
left_expr: Expression,
right_type: Type,
right_expr: Expression,
context: Context) -> Tuple[Type, Type]:
# Note: this kludge exists mostly to maintain compatibility with
# existing error messages. Apparently, if the left-hand-side is a
# union and we have a type mismatch, we print out a special,
# abbreviated error message. (See messages.unsupported_operand_types).
unions_present = isinstance(left_type, UnionType)

context: Context,
msg: MessageBuilder) -> Tuple[Type, Type]:
def make_local_errors() -> MessageBuilder:
"""Creates a new MessageBuilder object."""
local_errors = self.msg.clean_copy()
local_errors = msg.clean_copy()
local_errors.disable_count = 0
if unions_present:
local_errors.disable_type_names += 1
return local_errors

def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
Expand Down Expand Up @@ -2006,30 +1999,101 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# TODO: Remove this extra case
return result

self.msg.add_errors(errors[0])
msg.add_errors(errors[0])
if len(results) == 1:
return results[0]
else:
error_any = AnyType(TypeOfAny.from_error)
result = error_any, error_any
return result

def check_op(self, method: str, base_type: Type, arg: Expression,
context: Context,
def check_op(self, method: str, base_type: Type,
arg: Expression, context: Context,
allow_reverse: bool = False) -> Tuple[Type, Type]:
"""Type check a binary operation which maps to a method call.

Return tuple (result type, inferred operator method type).
"""

if allow_reverse:
return self.check_op_reversible(
op_name=method,
left_type=base_type,
left_expr=TempNode(base_type),
right_type=self.accept(arg),
right_expr=arg,
context=context)
left_variants = [base_type]
if isinstance(base_type, UnionType):
left_variants = [item for item in base_type.relevant_items()]
right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
# just the left ones. (Mypy can sometimes perform some more precise inference
# if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_type,
right_expr=arg,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if not msg.is_errors():
results_final = UnionType.make_simplified_union(all_results)
inferred_final = UnionType.make_simplified_union(all_inferred)
return results_final, inferred_final

# Step 2: If that fails, we try again but also destructure the right argument.
# This is also necessary to make certain edge cases work -- see
# testOperatorDoubleUnionInterwovenUnionAdd, for example.

# Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
# whenever possible so that plugins and similar things can introspect on the original
# node if possible.
#
# We don't do the same for the base expression because it could lead to weird
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
# TODO: Can we use `type_overrides_set()` here?
right_variants = [(right_type, arg)]
if isinstance(right_type, UnionType):
right_variants = [(item, TempNode(item)) for item in right_type.relevant_items()]

msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
for right_possible_type, right_expr in right_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_possible_type,
right_expr=right_expr,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if msg.is_errors():
self.msg.add_errors(msg)
if len(left_variants) >= 2 and len(right_variants) >= 2:
self.msg.warn_both_operands_are_from_unions(context)
elif len(left_variants) >= 2:
self.msg.warn_operand_was_from_union("Left", base_type, context)
elif len(right_variants) >= 2:
self.msg.warn_operand_was_from_union("Right", right_type, context)

# See the comment in 'check_overload_call' for more details on why
# we call 'combine_function_signature' instead of just unioning the inferred
# callable types.
results_final = UnionType.make_simplified_union(all_results)
inferred_final = self.combine_function_signatures(all_inferred)
return results_final, inferred_final
else:
return self.check_op_local_by_name(
method=method,
Expand Down
6 changes: 6 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,12 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No
self.fail('Overloaded function implementation cannot produce return type '
'of signature {}'.format(index), context)

def warn_both_operands_are_from_unions(self, context: Context) -> None:
self.note('Both left and right operands are unions', context)

def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None:
self.note('{} operand is of type {}'.format(side, self.format(original)), context)

def operator_method_signatures_overlap(
self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type,
forward_method: str, context: Context) -> None:
Expand Down
9 changes: 6 additions & 3 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ from typing import Callable, Union
x = 5 # type: Union[int, Callable[[], str], Callable[[], int]]

if callable(x):
y = x() + 2 # E: Unsupported operand types for + (likely involving Union)
y = x() + 2 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[str, int]"
else:
z = x + 6

Expand All @@ -60,7 +61,8 @@ x = 5 # type: Union[int, str, Callable[[], str]]
if callable(x):
y = x() + 'test'
else:
z = x + 6 # E: Unsupported operand types for + (likely involving Union)
z = x + 6 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[int, str]"

[builtins fixtures/callable.pyi]

Expand Down Expand Up @@ -153,7 +155,8 @@ x = 5 # type: Union[int, Callable[[], str]]
if callable(x) and x() == 'test':
x()
else:
x + 5 # E: Unsupported left operand type for + (some union)
x + 5 # E: Unsupported left operand type for + ("Callable[[], str]") \
# N: Left operand is of type "Union[int, Callable[[], str]]"

[builtins fixtures/callable.pyi]

Expand Down
Loading