Skip to content

Commit c8bae06

Browse files
Add UnionType support when incompatible protocol happens (#10154)
Fix #10129
1 parent 7241b6c commit c8bae06

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

mypy/messages.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,9 +572,16 @@ def incompatible_argument_note(self,
572572
callee_type: ProperType,
573573
context: Context,
574574
code: Optional[ErrorCode]) -> None:
575-
if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
576-
isinstance(callee_type, Instance) and callee_type.type.is_protocol):
577-
self.report_protocol_problems(original_caller_type, callee_type, context, code=code)
575+
if isinstance(original_caller_type, (Instance, TupleType, TypedDictType)):
576+
if isinstance(callee_type, Instance) and callee_type.type.is_protocol:
577+
self.report_protocol_problems(original_caller_type, callee_type,
578+
context, code=code)
579+
if isinstance(callee_type, UnionType):
580+
for item in callee_type.items:
581+
item = get_proper_type(item)
582+
if isinstance(item, Instance) and item.type.is_protocol:
583+
self.report_protocol_problems(original_caller_type, item,
584+
context, code=code)
578585
if (isinstance(callee_type, CallableType) and
579586
isinstance(original_caller_type, Instance)):
580587
call = find_member('__call__', original_caller_type, original_caller_type,

test-data/unit/check-protocols.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,6 +2073,38 @@ main:14: note: Got:
20732073
main:14: note: def g(self, x: str) -> None
20742074
main:14: note: <2 more conflict(s) not shown>
20752075

2076+
[case testProtocolIncompatibilityWithUnionType]
2077+
from typing import Any, Optional, Protocol
2078+
2079+
class A(Protocol):
2080+
def execute(self, statement: Any, *args: Any, **kwargs: Any) -> None: ...
2081+
2082+
class B(Protocol):
2083+
def execute(self, stmt: Any, *args: Any, **kwargs: Any) -> None: ...
2084+
def cool(self) -> None: ...
2085+
2086+
def func1(arg: A) -> None: ...
2087+
def func2(arg: Optional[A]) -> None: ...
2088+
2089+
x: B
2090+
func1(x)
2091+
func2(x)
2092+
[builtins fixtures/tuple.pyi]
2093+
[builtins fixtures/dict.pyi]
2094+
[out]
2095+
main:14: error: Argument 1 to "func1" has incompatible type "B"; expected "A"
2096+
main:14: note: Following member(s) of "B" have conflicts:
2097+
main:14: note: Expected:
2098+
main:14: note: def execute(self, statement: Any, *args: Any, **kwargs: Any) -> None
2099+
main:14: note: Got:
2100+
main:14: note: def execute(self, stmt: Any, *args: Any, **kwargs: Any) -> None
2101+
main:15: error: Argument 1 to "func2" has incompatible type "B"; expected "Optional[A]"
2102+
main:15: note: Following member(s) of "B" have conflicts:
2103+
main:15: note: Expected:
2104+
main:15: note: def execute(self, statement: Any, *args: Any, **kwargs: Any) -> None
2105+
main:15: note: Got:
2106+
main:15: note: def execute(self, stmt: Any, *args: Any, **kwargs: Any) -> None
2107+
20762108
[case testDontShowNotesForTupleAndIterableProtocol]
20772109
from typing import Iterable, Sequence, Protocol, NamedTuple
20782110

0 commit comments

Comments
 (0)