Skip to content

Commit 9012fc9

Browse files
authored
Some cleanup in partial plugin (#17423)
Fixes #17405 Apart from fixing the crash I fix two obvious bugs I noticed while making this PR.
1 parent cc3492e commit 9012fc9

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

mypy/checkexpr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,8 @@ def apply_function_plugin(
12281228
formal_arg_exprs[formal].append(args[actual])
12291229
if arg_names:
12301230
formal_arg_names[formal].append(arg_names[actual])
1231+
else:
1232+
formal_arg_names[formal].append(None)
12311233
formal_arg_kinds[formal].append(arg_kinds[actual])
12321234

12331235
if object_type is None:

mypy/plugins/functools.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Type,
1818
TypeOfAny,
1919
UnboundType,
20-
UninhabitedType,
2120
get_proper_type,
2221
)
2322

@@ -132,6 +131,9 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
132131
if fn_type is None:
133132
return ctx.default_return_type
134133

134+
# We must normalize from the start to have coherent view together with TypeChecker.
135+
fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()
136+
135137
defaulted = fn_type.copy_modified(
136138
arg_kinds=[
137139
(
@@ -146,10 +148,25 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
146148
# Make up a line number if we don't have one
147149
defaulted.set_line(ctx.default_return_type)
148150

149-
actual_args = [a for param in ctx.args[1:] for a in param]
150-
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
151-
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
152-
actual_types = [a for param in ctx.arg_types[1:] for a in param]
151+
# Flatten actual to formal mapping, since this is what check_call() expects.
152+
actual_args = []
153+
actual_arg_kinds = []
154+
actual_arg_names = []
155+
actual_types = []
156+
seen_args = set()
157+
for i, param in enumerate(ctx.args[1:], start=1):
158+
for j, a in enumerate(param):
159+
if a in seen_args:
160+
# Same actual arg can map to multiple formals, but we need to include
161+
# each one only once.
162+
continue
163+
# Here we rely on the fact that expressions are essentially immutable, so
164+
# they can be compared by identity.
165+
seen_args.add(a)
166+
actual_args.append(a)
167+
actual_arg_kinds.append(ctx.arg_kinds[i][j])
168+
actual_arg_names.append(ctx.arg_names[i][j])
169+
actual_types.append(ctx.arg_types[i][j])
153170

154171
# Create a valid context for various ad-hoc inspections in check_call().
155172
call_expr = CallExpr(
@@ -188,7 +205,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
188205
for i, actuals in enumerate(formal_to_actual):
189206
if len(bound.arg_types) == len(fn_type.arg_types):
190207
arg_type = bound.arg_types[i]
191-
if isinstance(get_proper_type(arg_type), UninhabitedType):
208+
if not mypy.checker.is_valid_inferred_type(arg_type):
192209
arg_type = fn_type.arg_types[i] # bit of a hack
193210
else:
194211
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
@@ -210,7 +227,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
210227
partial_names.append(fn_type.arg_names[i])
211228

212229
ret_type = bound.ret_type
213-
if isinstance(get_proper_type(ret_type), UninhabitedType):
230+
if not mypy.checker.is_valid_inferred_type(ret_type):
214231
ret_type = fn_type.ret_type # same kind of hack as above
215232

216233
partially_applied = fn_type.copy_modified(

test-data/unit/check-functools.test

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,55 @@ def foo(cls3: Type[B[T]]):
372372
reveal_type(functools.partial(cls3, 2)()) # N: Revealed type is "__main__.B[T`-1]" \
373373
# E: Argument 1 to "B" has incompatible type "int"; expected "T"
374374
[builtins fixtures/tuple.pyi]
375+
376+
[case testFunctoolsPartialTypedDictUnpack]
377+
from typing_extensions import TypedDict, Unpack
378+
from functools import partial
379+
380+
class Data(TypedDict, total=False):
381+
x: int
382+
383+
def f(**kwargs: Unpack[Data]) -> None: ...
384+
def g(**kwargs: Unpack[Data]) -> None:
385+
partial(f, **kwargs)()
386+
387+
class MoreData(TypedDict, total=False):
388+
x: int
389+
y: int
390+
391+
def f_more(**kwargs: Unpack[MoreData]) -> None: ...
392+
def g_more(**kwargs: Unpack[MoreData]) -> None:
393+
partial(f_more, **kwargs)()
394+
395+
class Good(TypedDict, total=False):
396+
y: int
397+
class Bad(TypedDict, total=False):
398+
y: str
399+
400+
def h(**kwargs: Unpack[Data]) -> None:
401+
bad: Bad
402+
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
403+
good: Good
404+
partial(f_more, **kwargs)(**good)
405+
[builtins fixtures/dict.pyi]
406+
407+
[case testFunctoolsPartialNestedGeneric]
408+
from functools import partial
409+
from typing import Generic, TypeVar, List
410+
411+
T = TypeVar("T")
412+
def get(n: int, args: List[T]) -> T: ...
413+
first = partial(get, 0)
414+
415+
x: List[str]
416+
reveal_type(first(x)) # N: Revealed type is "builtins.str"
417+
reveal_type(first([1])) # N: Revealed type is "builtins.int"
418+
419+
first_kw = partial(get, n=0)
420+
reveal_type(first_kw(args=[1])) # N: Revealed type is "builtins.int"
421+
422+
# TODO: this is indeed invalid, but the error is incomprehensible.
423+
first_kw([1]) # E: Too many positional arguments for "get" \
424+
# E: Too few arguments for "get" \
425+
# E: Argument 1 to "get" has incompatible type "List[int]"; expected "int"
426+
[builtins fixtures/list.pyi]

0 commit comments

Comments
 (0)