Skip to content

Commit 3ddc009

Browse files
authored
Support unions in functools.partial (#17284)
Co-authored-by: cdce8p
1 parent 43a605f commit 3ddc009

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

mypy/checker.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from typing_extensions import TypeAlias as _TypeAlias
2727

2828
import mypy.checkexpr
29-
from mypy import errorcodes as codes, message_registry, nodes, operators
29+
from mypy import errorcodes as codes, join, message_registry, nodes, operators
3030
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
3131
from mypy.checkmember import (
3232
MemberContext,
@@ -699,6 +699,21 @@ def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> Callab
699699
)
700700
if isinstance(inner_call, CallableType):
701701
outer_type = inner_call
702+
elif isinstance(inner_type, UnionType):
703+
union_type = make_simplified_union(inner_type.items)
704+
if isinstance(union_type, UnionType):
705+
items = []
706+
for item in union_type.items:
707+
callable_item = self.extract_callable_type(item, ctx)
708+
if callable_item is None:
709+
break
710+
items.append(callable_item)
711+
else:
712+
joined_type = get_proper_type(join.join_type_list(items))
713+
if isinstance(joined_type, CallableType):
714+
outer_type = joined_type
715+
else:
716+
return self.extract_callable_type(union_type, ctx)
702717
if outer_type is None:
703718
self.msg.not_callable(inner_type, ctx)
704719
return outer_type

mypy/join.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import overload
5+
from typing import Sequence, overload
66

77
import mypy.typeops
88
from mypy.maptype import map_instance_to_supertype
@@ -853,7 +853,7 @@ def object_or_any_from_type(typ: ProperType) -> ProperType:
853853
return AnyType(TypeOfAny.implementation_artifact)
854854

855855

856-
def join_type_list(types: list[Type]) -> Type:
856+
def join_type_list(types: Sequence[Type]) -> Type:
857857
if not types:
858858
# This is a little arbitrary but reasonable. Any empty tuple should be compatible
859859
# with all variable length tuples, and this makes it possible.

test-data/unit/check-functools.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,25 @@ p(bar, 1, "a", 3.0) # OK
324324
p(bar, 1, "a", 3.0, kwarg="asdf") # OK
325325
p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]"
326326
[builtins fixtures/dict.pyi]
327+
328+
[case testFunctoolsPartialUnion]
329+
import functools
330+
from typing import Any, Callable, Union
331+
332+
cls1: Any
333+
cls2: Union[Any, Any]
334+
reveal_type(functools.partial(cls1, 2)()) # N: Revealed type is "Any"
335+
reveal_type(functools.partial(cls2, 2)()) # N: Revealed type is "Any"
336+
337+
fn1: Union[Callable[[int], int], Callable[[int], int]]
338+
reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int"
339+
340+
fn2: Union[Callable[[int], int], Callable[[int], str]]
341+
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object"
342+
343+
fn3: Union[Callable[[int], int], str]
344+
reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
345+
# E: "Union[Callable[[int], int], str]" not callable \
346+
# N: Revealed type is "builtins.int" \
347+
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
348+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)