Skip to content

Commit 7141d6b

Browse files
authored
More principled approach for callable vs callable inference (#15910)
Fixes #702 (one of the oldest open issues) The approach is quite simple, I essentially replicate the logic from subtyping check, while replacing each `is_subtype()` call with `infer_constraints()` call. Note that we don't have various options available in `constraints.py` so I use all checks, even those that may be skipped with some strictness flags (so we can infer as many constraints as possible). Depending on the output of `mypy_primer` we can try to tune this. Note that while I was looking at subtyping code, I noticed couple inconsistencies for ParamSpecs, I added TODOs for them (and updated some existing TODOs). I also deleted some code that should be dead code after my previous cleanup. Among inconsistencies most notably, subtyping between `Parameters` uses wrong (opposite) direction. Normally, `Parameters` entity behaves covariantly (w.r.t. types of individual arguments) as a single big argument, like a tuple plus a map. But then this entity appears in a contravariant position in `Callable`. This is how we handle it in `constraints.py`, `join.py`, `meet.py` etc. I tried to fix the left/right order in `visit_parameters()`, but then one test failed (and btw same test would also fail if I would try to fix variance in `visit_instance()`). I decided to leave this for separate PR(s).
1 parent e804e8d commit 7141d6b

File tree

5 files changed

+260
-55
lines changed

5 files changed

+260
-55
lines changed

mypy/constraints.py

Lines changed: 101 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -595,15 +595,11 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]:
595595
return self.infer_against_any(template.arg_types, self.actual)
596596
if type_state.infer_polymorphic and isinstance(self.actual, Parameters):
597597
# For polymorphic inference we need to be able to infer secondary constraints
598-
# in situations like [x: T] <: P <: [x: int].
599-
res = []
600-
if len(template.arg_types) == len(self.actual.arg_types):
601-
for tt, at in zip(template.arg_types, self.actual.arg_types):
602-
# This avoids bogus constraints like T <: P.args
603-
if isinstance(at, ParamSpecType):
604-
continue
605-
res.extend(infer_constraints(tt, at, self.direction))
606-
return res
598+
# in situations like [x: T] <: P <: [x: int]. Note we invert direction, since
599+
# this function expects direction between callables.
600+
return infer_callable_arguments_constraints(
601+
template, self.actual, neg_op(self.direction)
602+
)
607603
raise RuntimeError("Parameters cannot be constrained to")
608604

609605
# Non-leaf types
@@ -722,7 +718,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
722718
prefix = mapped_arg.prefix
723719
if isinstance(instance_arg, Parameters):
724720
# No such thing as variance for ParamSpecs, consider them invariant
725-
# TODO: constraints between prefixes
721+
# TODO: constraints between prefixes using
722+
# infer_callable_arguments_constraints()
726723
suffix: Type = instance_arg.copy_modified(
727724
instance_arg.arg_types[len(prefix.arg_types) :],
728725
instance_arg.arg_kinds[len(prefix.arg_kinds) :],
@@ -793,7 +790,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
793790
prefix = template_arg.prefix
794791
if isinstance(mapped_arg, Parameters):
795792
# No such thing as variance for ParamSpecs, consider them invariant
796-
# TODO: constraints between prefixes
793+
# TODO: constraints between prefixes using
794+
# infer_callable_arguments_constraints()
797795
suffix = mapped_arg.copy_modified(
798796
mapped_arg.arg_types[len(prefix.arg_types) :],
799797
mapped_arg.arg_kinds[len(prefix.arg_kinds) :],
@@ -962,24 +960,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
962960
unpack_constraints = build_constraints_for_simple_unpack(
963961
template_types, actual_types, neg_op(self.direction)
964962
)
965-
template_args = []
966-
cactual_args = []
967963
res.extend(unpack_constraints)
968964
else:
969-
template_args = template.arg_types
970-
cactual_args = cactual.arg_types
971-
# TODO: use some more principled "formal to actual" logic
972-
# instead of this lock-step loop over argument types. This identical
973-
# logic should be used in 5 places: in Parameters vs Parameters
974-
# inference, in Instance vs Instance inference for prefixes (two
975-
# branches), and in Callable vs Callable inference (two branches).
976-
for t, a in zip(template_args, cactual_args):
977-
# This avoids bogus constraints like T <: P.args
978-
if isinstance(a, (ParamSpecType, UnpackType)):
979-
# TODO: can we infer something useful for *T vs P?
980-
continue
981965
# Negate direction due to function argument type contravariance.
982-
res.extend(infer_constraints(t, a, neg_op(self.direction)))
966+
res.extend(
967+
infer_callable_arguments_constraints(template, cactual, self.direction)
968+
)
983969
else:
984970
prefix = param_spec.prefix
985971
prefix_len = len(prefix.arg_types)
@@ -1028,11 +1014,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10281014
arg_kinds=cactual.arg_kinds[:prefix_len],
10291015
arg_names=cactual.arg_names[:prefix_len],
10301016
)
1031-
1032-
for t, a in zip(prefix.arg_types, cactual_prefix.arg_types):
1033-
if isinstance(a, ParamSpecType):
1034-
continue
1035-
res.extend(infer_constraints(t, a, neg_op(self.direction)))
1017+
res.extend(
1018+
infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction)
1019+
)
10361020

10371021
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
10381022
if template.type_guard is not None:
@@ -1435,3 +1419,89 @@ def build_constraints_for_unpack(
14351419
for template_arg, item in zip(template_unpack.items, mapped_middle):
14361420
res.extend(infer_constraints(template_arg, item, direction))
14371421
return res, mapped_prefix + mapped_suffix, template_prefix + template_suffix
1422+
1423+
1424+
def infer_directed_arg_constraints(left: Type, right: Type, direction: int) -> list[Constraint]:
1425+
"""Infer constraints between two arguments using direction between original callables."""
1426+
if isinstance(left, (ParamSpecType, UnpackType)) or isinstance(
1427+
right, (ParamSpecType, UnpackType)
1428+
):
1429+
# This avoids bogus constraints like T <: P.args
1430+
# TODO: can we infer something useful for *T vs P?
1431+
return []
1432+
if direction == SUBTYPE_OF:
1433+
# We invert direction to account for argument contravariance.
1434+
return infer_constraints(left, right, neg_op(direction))
1435+
else:
1436+
return infer_constraints(right, left, neg_op(direction))
1437+
1438+
1439+
def infer_callable_arguments_constraints(
1440+
template: CallableType | Parameters, actual: CallableType | Parameters, direction: int
1441+
) -> list[Constraint]:
1442+
"""Infer constraints between argument types of two callables.
1443+
1444+
This function essentially extracts four steps from are_parameters_compatible() in
1445+
subtypes.py that involve subtype checks between argument types. We keep the argument
1446+
matching logic, but ignore various strictness flags present there, and checks that
1447+
do not involve subtyping. Then in place of every subtype check we put an infer_constraints()
1448+
call for the same types.
1449+
"""
1450+
res = []
1451+
if direction == SUBTYPE_OF:
1452+
left, right = template, actual
1453+
else:
1454+
left, right = actual, template
1455+
left_star = left.var_arg()
1456+
left_star2 = left.kw_arg()
1457+
right_star = right.var_arg()
1458+
right_star2 = right.kw_arg()
1459+
1460+
# Numbering of steps below matches the one in are_parameters_compatible() for convenience.
1461+
# Phase 1a: compare star vs star arguments.
1462+
if left_star is not None and right_star is not None:
1463+
res.extend(infer_directed_arg_constraints(left_star.typ, right_star.typ, direction))
1464+
if left_star2 is not None and right_star2 is not None:
1465+
res.extend(infer_directed_arg_constraints(left_star2.typ, right_star2.typ, direction))
1466+
1467+
# Phase 1b: compare left args with corresponding non-star right arguments.
1468+
for right_arg in right.formal_arguments():
1469+
left_arg = mypy.typeops.callable_corresponding_argument(left, right_arg)
1470+
if left_arg is None:
1471+
continue
1472+
res.extend(infer_directed_arg_constraints(left_arg.typ, right_arg.typ, direction))
1473+
1474+
# Phase 1c: compare left args with right *args.
1475+
if right_star is not None:
1476+
right_by_position = right.try_synthesizing_arg_from_vararg(None)
1477+
assert right_by_position is not None
1478+
i = right_star.pos
1479+
assert i is not None
1480+
while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional():
1481+
left_by_position = left.argument_by_position(i)
1482+
assert left_by_position is not None
1483+
res.extend(
1484+
infer_directed_arg_constraints(
1485+
left_by_position.typ, right_by_position.typ, direction
1486+
)
1487+
)
1488+
i += 1
1489+
1490+
# Phase 1d: compare left args with right **kwargs.
1491+
if right_star2 is not None:
1492+
right_names = {name for name in right.arg_names if name is not None}
1493+
left_only_names = set()
1494+
for name, kind in zip(left.arg_names, left.arg_kinds):
1495+
if name is None or kind.is_star() or name in right_names:
1496+
continue
1497+
left_only_names.add(name)
1498+
1499+
right_by_name = right.try_synthesizing_arg_from_kwarg(None)
1500+
assert right_by_name is not None
1501+
for name in left_only_names:
1502+
left_by_name = left.argument_by_name(name)
1503+
assert left_by_name is not None
1504+
res.extend(
1505+
infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction)
1506+
)
1507+
return res

mypy/subtypes.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ def check_mixed(
590590
):
591591
nominal = False
592592
else:
593+
# TODO: everywhere else ParamSpecs are handled as invariant.
593594
if not check_type_parameter(
594595
lefta, righta, COVARIANT, self.proper_subtype, self.subtype_context
595596
):
@@ -666,13 +667,12 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
666667
return False
667668

668669
def visit_parameters(self, left: Parameters) -> bool:
669-
if isinstance(self.right, (Parameters, CallableType)):
670-
right = self.right
671-
if isinstance(right, CallableType):
672-
right = right.with_unpacked_kwargs()
670+
if isinstance(self.right, Parameters):
671+
# TODO: direction here should be opposite, this function expects
672+
# order of callables, while parameters are contravariant.
673673
return are_parameters_compatible(
674674
left,
675-
right,
675+
self.right,
676676
is_compat=self._is_subtype,
677677
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
678678
)
@@ -723,14 +723,6 @@ def visit_callable_type(self, left: CallableType) -> bool:
723723
elif isinstance(right, TypeType):
724724
# This is unsound, we don't check the __init__ signature.
725725
return left.is_type_obj() and self._is_subtype(left.ret_type, right.item)
726-
elif isinstance(right, Parameters):
727-
# this doesn't check return types.... but is needed for is_equivalent
728-
return are_parameters_compatible(
729-
left.with_unpacked_kwargs(),
730-
right,
731-
is_compat=self._is_subtype,
732-
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
733-
)
734726
else:
735727
return False
736728

@@ -1456,7 +1448,6 @@ def g(x: int) -> int: ...
14561448
right,
14571449
is_compat=is_compat,
14581450
ignore_pos_arg_names=ignore_pos_arg_names,
1459-
check_args_covariantly=check_args_covariantly,
14601451
allow_partial_overlap=allow_partial_overlap,
14611452
strict_concatenate_check=strict_concatenate_check,
14621453
)
@@ -1480,7 +1471,6 @@ def are_parameters_compatible(
14801471
*,
14811472
is_compat: Callable[[Type, Type], bool],
14821473
ignore_pos_arg_names: bool = False,
1483-
check_args_covariantly: bool = False,
14841474
allow_partial_overlap: bool = False,
14851475
strict_concatenate_check: bool = False,
14861476
) -> bool:
@@ -1534,7 +1524,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
15341524

15351525
# Phase 1b: Check non-star args: for every arg right can accept, left must
15361526
# also accept. The only exception is if we are allowing partial
1537-
# partial overlaps: in that case, we ignore optional args on the right.
1527+
# overlaps: in that case, we ignore optional args on the right.
15381528
for right_arg in right.formal_arguments():
15391529
left_arg = mypy.typeops.callable_corresponding_argument(left, right_arg)
15401530
if left_arg is None:
@@ -1548,7 +1538,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
15481538

15491539
# Phase 1c: Check var args. Right has an infinite series of optional positional
15501540
# arguments. Get all further positional args of left, and make sure
1551-
# they're more general then the corresponding member in right.
1541+
# they're more general than the corresponding member in right.
15521542
if right_star is not None:
15531543
# Synthesize an anonymous formal argument for the right
15541544
right_by_position = right.try_synthesizing_arg_from_vararg(None)
@@ -1575,7 +1565,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
15751565

15761566
# Phase 1d: Check kw args. Right has an infinite series of optional named
15771567
# arguments. Get all further named args of left, and make sure
1578-
# they're more general then the corresponding member in right.
1568+
# they're more general than the corresponding member in right.
15791569
if right_star2 is not None:
15801570
right_names = {name for name in right.arg_names if name is not None}
15811571
left_only_names = set()
@@ -1643,6 +1633,10 @@ def are_args_compatible(
16431633
allow_partial_overlap: bool,
16441634
is_compat: Callable[[Type, Type], bool],
16451635
) -> bool:
1636+
if left.required and right.required:
1637+
# If both arguments are required allow_partial_overlap has no effect.
1638+
allow_partial_overlap = False
1639+
16461640
def is_different(left_item: object | None, right_item: object | None) -> bool:
16471641
"""Checks if the left and right items are different.
16481642
@@ -1670,7 +1664,7 @@ def is_different(left_item: object | None, right_item: object | None) -> bool:
16701664

16711665
# If right's argument is optional, left's must also be
16721666
# (unless we're relaxing the checks to allow potential
1673-
# rather then definite compatibility).
1667+
# rather than definite compatibility).
16741668
if not allow_partial_overlap and not right.required and left.required:
16751669
return False
16761670

mypy/types.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,9 +1545,6 @@ class FormalArgument(NamedTuple):
15451545
required: bool
15461546

15471547

1548-
# TODO: should this take bound typevars too? what would this take?
1549-
# ex: class Z(Generic[P, T]): ...; Z[[V], V]
1550-
# What does a typevar even mean in this context?
15511548
class Parameters(ProperType):
15521549
"""Type that represents the parameters to a function.
15531550
@@ -1559,6 +1556,8 @@ class Parameters(ProperType):
15591556
"arg_names",
15601557
"min_args",
15611558
"is_ellipsis_args",
1559+
# TODO: variables don't really belong here, but they are used to allow hacky support
1560+
# for forall . Foo[[x: T], T] by capturing generic callable with ParamSpec, see #15909
15621561
"variables",
15631562
)
15641563

@@ -1602,7 +1601,7 @@ def copy_modified(
16021601
variables=variables if variables is not _dummy else self.variables,
16031602
)
16041603

1605-
# the following are copied from CallableType. Is there a way to decrease code duplication?
1604+
# TODO: here is a lot of code duplication with Callable type, fix this.
16061605
def var_arg(self) -> FormalArgument | None:
16071606
"""The formal argument for *args."""
16081607
for position, (type, kind) in enumerate(zip(self.arg_types, self.arg_kinds)):
@@ -2046,7 +2045,6 @@ def param_spec(self) -> ParamSpecType | None:
20462045
return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix)
20472046

20482047
def expand_param_spec(self, c: Parameters) -> CallableType:
2049-
# TODO: try deleting variables from Parameters after new type inference is default.
20502048
variables = c.variables
20512049
return self.copy_modified(
20522050
arg_types=self.arg_types[:-2] + c.arg_types,

0 commit comments

Comments
 (0)