Skip to content

Commit 7f65cc7

Browse files
authored
Infer ParamSpec constraint from arguments (#15896)
Fixes #12278 Fixes #13191 (more tricky nested use cases with optional/keyword args still don't work, but they are quite tricky to fix and may selectively fixed later) This unfortunately requires some special-casing, here is its summary: * If actual argument for `Callable[P, T]` is non-generic and non-lambda, do not put it into inference second pass. * If we are able to infer constraints for `P` without using arguments mapped to `*args: P.args` etc., do not add the constraint for `P` vs those arguments (this applies to both top-level callable constraints, and for nested callable constraints against callables that are known to have imprecise argument kinds). (Btw TODO I added is not related to this PR, I just noticed something obviously wrong)
1 parent f9b1db6 commit 7f65cc7

File tree

8 files changed

+244
-69
lines changed

8 files changed

+244
-69
lines changed

mypy/checkexpr.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,7 @@ def infer_function_type_arguments(
19871987
)
19881988

19891989
arg_pass_nums = self.get_arg_infer_passes(
1990-
callee_type.arg_types, formal_to_actual, len(args)
1990+
callee_type, args, arg_types, formal_to_actual, len(args)
19911991
)
19921992

19931993
pass1_args: list[Type | None] = []
@@ -2001,6 +2001,7 @@ def infer_function_type_arguments(
20012001
callee_type,
20022002
pass1_args,
20032003
arg_kinds,
2004+
arg_names,
20042005
formal_to_actual,
20052006
context=self.argument_infer_context(),
20062007
strict=self.chk.in_checked_function(),
@@ -2061,6 +2062,7 @@ def infer_function_type_arguments(
20612062
callee_type,
20622063
arg_types,
20632064
arg_kinds,
2065+
arg_names,
20642066
formal_to_actual,
20652067
context=self.argument_infer_context(),
20662068
strict=self.chk.in_checked_function(),
@@ -2140,6 +2142,7 @@ def infer_function_type_arguments_pass2(
21402142
callee_type,
21412143
arg_types,
21422144
arg_kinds,
2145+
arg_names,
21432146
formal_to_actual,
21442147
context=self.argument_infer_context(),
21452148
)
@@ -2152,7 +2155,12 @@ def argument_infer_context(self) -> ArgumentInferContext:
21522155
)
21532156

21542157
def get_arg_infer_passes(
2155-
self, arg_types: list[Type], formal_to_actual: list[list[int]], num_actuals: int
2158+
self,
2159+
callee: CallableType,
2160+
args: list[Expression],
2161+
arg_types: list[Type],
2162+
formal_to_actual: list[list[int]],
2163+
num_actuals: int,
21562164
) -> list[int]:
21572165
"""Return pass numbers for args for two-pass argument type inference.
21582166
@@ -2163,8 +2171,28 @@ def get_arg_infer_passes(
21632171
lambdas more effectively.
21642172
"""
21652173
res = [1] * num_actuals
2166-
for i, arg in enumerate(arg_types):
2167-
if arg.accept(ArgInferSecondPassQuery()):
2174+
for i, arg in enumerate(callee.arg_types):
2175+
skip_param_spec = False
2176+
p_formal = get_proper_type(callee.arg_types[i])
2177+
if isinstance(p_formal, CallableType) and p_formal.param_spec():
2178+
for j in formal_to_actual[i]:
2179+
p_actual = get_proper_type(arg_types[j])
2180+
# This is an exception from the usual logic where we put generic Callable
2181+
# arguments in the second pass. If we have a non-generic actual, it is
2182+
# likely to infer good constraints, for example if we have:
2183+
# def run(Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
2184+
# def test(x: int, y: int) -> int: ...
2185+
# run(test, 1, 2)
2186+
# we will use `test` for inference, since it will allow to infer also
2187+
# argument *names* for P <: [x: int, y: int].
2188+
if (
2189+
isinstance(p_actual, CallableType)
2190+
and not p_actual.variables
2191+
and not isinstance(args[j], LambdaExpr)
2192+
):
2193+
skip_param_spec = True
2194+
break
2195+
if not skip_param_spec and arg.accept(ArgInferSecondPassQuery()):
21682196
for j in formal_to_actual[i]:
21692197
res[j] = 2
21702198
return res
@@ -4903,7 +4931,9 @@ def infer_lambda_type_using_context(
49034931
self.chk.fail(message_registry.CANNOT_INFER_LAMBDA_TYPE, e)
49044932
return None, None
49054933

4906-
return callable_ctx, callable_ctx
4934+
# Type of lambda must have correct argument names, to prevent false
4935+
# negatives when lambdas appear in `ParamSpec` context.
4936+
return callable_ctx.copy_modified(arg_names=e.arg_names), callable_ctx
49074937

49084938
def visit_super_expr(self, e: SuperExpr) -> Type:
49094939
"""Type check a super expression (non-lvalue)."""
@@ -5921,6 +5951,7 @@ def __init__(self) -> None:
59215951
super().__init__(types.ANY_STRATEGY)
59225952

59235953
def visit_callable_type(self, t: CallableType) -> bool:
5954+
# TODO: we need to check only for type variables of original callable.
59245955
return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery())
59255956

59265957

mypy/constraints.py

Lines changed: 97 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def infer_constraints_for_callable(
108108
callee: CallableType,
109109
arg_types: Sequence[Type | None],
110110
arg_kinds: list[ArgKind],
111+
arg_names: Sequence[str | None] | None,
111112
formal_to_actual: list[list[int]],
112113
context: ArgumentInferContext,
113114
) -> list[Constraint]:
@@ -118,6 +119,20 @@ def infer_constraints_for_callable(
118119
constraints: list[Constraint] = []
119120
mapper = ArgTypeExpander(context)
120121

122+
param_spec = callee.param_spec()
123+
param_spec_arg_types = []
124+
param_spec_arg_names = []
125+
param_spec_arg_kinds = []
126+
127+
incomplete_star_mapping = False
128+
for i, actuals in enumerate(formal_to_actual):
129+
for actual in actuals:
130+
if actual is None and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
131+
# We can't use arguments to infer ParamSpec constraint, if only some
132+
# are present in the current inference pass.
133+
incomplete_star_mapping = True
134+
break
135+
121136
for i, actuals in enumerate(formal_to_actual):
122137
if isinstance(callee.arg_types[i], UnpackType):
123138
unpack_type = callee.arg_types[i]
@@ -194,11 +209,47 @@ def infer_constraints_for_callable(
194209
actual_type = mapper.expand_actual_type(
195210
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
196211
)
197-
# TODO: if callee has ParamSpec, we need to collect all actuals that map to star
198-
# args and create single constraint between P and resulting Parameters instead.
199-
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
200-
constraints.extend(c)
201-
212+
if (
213+
param_spec
214+
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
215+
and not incomplete_star_mapping
216+
):
217+
# If actual arguments are mapped to ParamSpec type, we can't infer individual
218+
# constraints, instead store them and infer single constraint at the end.
219+
# It is impossible to map actual kind to formal kind, so use some heuristic.
220+
# This inference is used as a fallback, so relying on heuristic should be OK.
221+
param_spec_arg_types.append(
222+
mapper.expand_actual_type(
223+
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
224+
)
225+
)
226+
actual_kind = arg_kinds[actual]
227+
param_spec_arg_kinds.append(
228+
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
229+
)
230+
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
231+
else:
232+
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
233+
constraints.extend(c)
234+
if (
235+
param_spec
236+
and not any(c.type_var == param_spec.id for c in constraints)
237+
and not incomplete_star_mapping
238+
):
239+
# Use ParamSpec constraint from arguments only if there are no other constraints,
240+
# since as explained above it is quite ad-hoc.
241+
constraints.append(
242+
Constraint(
243+
param_spec,
244+
SUPERTYPE_OF,
245+
Parameters(
246+
arg_types=param_spec_arg_types,
247+
arg_kinds=param_spec_arg_kinds,
248+
arg_names=param_spec_arg_names,
249+
imprecise_arg_kinds=True,
250+
),
251+
)
252+
)
202253
return constraints
203254

204255

@@ -949,6 +1000,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9491000
res: list[Constraint] = []
9501001
cactual = self.actual.with_unpacked_kwargs()
9511002
param_spec = template.param_spec()
1003+
1004+
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
1005+
if template.type_guard is not None:
1006+
template_ret_type = template.type_guard
1007+
if cactual.type_guard is not None:
1008+
cactual_ret_type = cactual.type_guard
1009+
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
1010+
9521011
if param_spec is None:
9531012
# TODO: Erase template variables if it is generic?
9541013
if (
@@ -1008,51 +1067,50 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10081067
)
10091068
extra_tvars = True
10101069

1070+
# Compare prefixes as well
1071+
cactual_prefix = cactual.copy_modified(
1072+
arg_types=cactual.arg_types[:prefix_len],
1073+
arg_kinds=cactual.arg_kinds[:prefix_len],
1074+
arg_names=cactual.arg_names[:prefix_len],
1075+
)
1076+
res.extend(
1077+
infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction)
1078+
)
1079+
1080+
param_spec_target: Type | None = None
1081+
skip_imprecise = (
1082+
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
1083+
)
10111084
if not cactual_ps:
10121085
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
10131086
prefix_len = min(prefix_len, max_prefix_len)
1014-
res.append(
1015-
Constraint(
1016-
param_spec,
1017-
neg_op(self.direction),
1018-
Parameters(
1019-
arg_types=cactual.arg_types[prefix_len:],
1020-
arg_kinds=cactual.arg_kinds[prefix_len:],
1021-
arg_names=cactual.arg_names[prefix_len:],
1022-
variables=cactual.variables
1023-
if not type_state.infer_polymorphic
1024-
else [],
1025-
),
1087+
# This logic matches top-level callable constraint exception, if we managed
1088+
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
1089+
if not skip_imprecise:
1090+
param_spec_target = Parameters(
1091+
arg_types=cactual.arg_types[prefix_len:],
1092+
arg_kinds=cactual.arg_kinds[prefix_len:],
1093+
arg_names=cactual.arg_names[prefix_len:],
1094+
variables=cactual.variables
1095+
if not type_state.infer_polymorphic
1096+
else [],
1097+
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
10261098
)
1027-
)
10281099
else:
1029-
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
1030-
cactual_ps = cactual_ps.copy_modified(
1100+
if (
1101+
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
1102+
and not skip_imprecise
1103+
):
1104+
param_spec_target = cactual_ps.copy_modified(
10311105
prefix=Parameters(
10321106
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
10331107
arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:],
10341108
arg_names=cactual_ps.prefix.arg_names[prefix_len:],
1109+
imprecise_arg_kinds=cactual_ps.prefix.imprecise_arg_kinds,
10351110
)
10361111
)
1037-
res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps))
1038-
1039-
# Compare prefixes as well
1040-
cactual_prefix = cactual.copy_modified(
1041-
arg_types=cactual.arg_types[:prefix_len],
1042-
arg_kinds=cactual.arg_kinds[:prefix_len],
1043-
arg_names=cactual.arg_names[:prefix_len],
1044-
)
1045-
res.extend(
1046-
infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction)
1047-
)
1048-
1049-
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
1050-
if template.type_guard is not None:
1051-
template_ret_type = template.type_guard
1052-
if cactual.type_guard is not None:
1053-
cactual_ret_type = cactual.type_guard
1054-
1055-
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
1112+
if param_spec_target is not None:
1113+
res.append(Constraint(param_spec, neg_op(self.direction), param_spec_target))
10561114
if extra_tvars:
10571115
for c in res:
10581116
c.extra_tvars += cactual.variables

mypy/expandtype.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
336336
arg_types=self.expand_types(t.arg_types),
337337
ret_type=t.ret_type.accept(self),
338338
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
339+
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
339340
)
340341
elif isinstance(repl, ParamSpecType):
341342
# We're substituting one ParamSpec for another; this can mean that the prefix
@@ -352,6 +353,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
352353
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
353354
ret_type=t.ret_type.accept(self),
354355
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
356+
imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
355357
)
356358

357359
var_arg = t.var_arg()

mypy/infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def infer_function_type_arguments(
3333
callee_type: CallableType,
3434
arg_types: Sequence[Type | None],
3535
arg_kinds: list[ArgKind],
36+
arg_names: Sequence[str | None] | None,
3637
formal_to_actual: list[list[int]],
3738
context: ArgumentInferContext,
3839
strict: bool = True,
@@ -53,7 +54,7 @@ def infer_function_type_arguments(
5354
"""
5455
# Infer constraints.
5556
constraints = infer_constraints_for_callable(
56-
callee_type, arg_types, arg_kinds, formal_to_actual, context
57+
callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
5758
)
5859

5960
# Solve constraints.

0 commit comments

Comments
 (0)