diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 81a3b4d96ef3..3faa1398abd9 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -7,6 +7,7 @@ import mypy.checker import mypy.plugin from mypy.argmap import map_actuals_to_formals +from mypy.expandtype import expand_type from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var from mypy.plugins.common import add_method_to_class from mypy.types import ( @@ -124,11 +125,12 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: return ctx.default_return_type if len(ctx.arg_types[0]) != 1: return ctx.default_return_type + fn = expand_type(get_proper_type(ctx.arg_types[0][0]), {}) - if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): + if isinstance(fn, Overloaded): # TODO: handle overloads, just fall back to whatever the non-plugin code does return ctx.default_return_type - fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) + fn_type = ctx.api.extract_callable_type(fn, ctx=ctx.default_return_type) if fn_type is None: return ctx.default_return_type diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index 5af5dfc8e469..67d11f5fce73 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -324,3 +324,13 @@ p(bar, 1, "a", 3.0) # OK p(bar, 1, "a", 3.0, kwarg="asdf") # OK p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]" [builtins fixtures/dict.pyi] + +[case testFunctoolsPartialUnion] +import functools +from typing import Any, Union + +cls1: Any +cls2: Union[Any, Any] +functools.partial(cls1, 2) +functools.partial(cls2, 2) +[builtins fixtures/tuple.pyi]