@@ -595,15 +595,11 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]:
595
595
return self .infer_against_any (template .arg_types , self .actual )
596
596
if type_state .infer_polymorphic and isinstance (self .actual , Parameters ):
597
597
# For polymorphic inference we need to be able to infer secondary constraints
598
- # in situations like [x: T] <: P <: [x: int].
599
- res = []
600
- if len (template .arg_types ) == len (self .actual .arg_types ):
601
- for tt , at in zip (template .arg_types , self .actual .arg_types ):
602
- # This avoids bogus constraints like T <: P.args
603
- if isinstance (at , ParamSpecType ):
604
- continue
605
- res .extend (infer_constraints (tt , at , self .direction ))
606
- return res
598
+ # in situations like [x: T] <: P <: [x: int]. Note we invert direction, since
599
+ # this function expects direction between callables.
600
+ return infer_callable_arguments_constraints (
601
+ template , self .actual , neg_op (self .direction )
602
+ )
607
603
raise RuntimeError ("Parameters cannot be constrained to" )
608
604
609
605
# Non-leaf types
@@ -722,7 +718,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
722
718
prefix = mapped_arg .prefix
723
719
if isinstance (instance_arg , Parameters ):
724
720
# No such thing as variance for ParamSpecs, consider them invariant
725
- # TODO: constraints between prefixes
721
+ # TODO: constraints between prefixes using
722
+ # infer_callable_arguments_constraints()
726
723
suffix : Type = instance_arg .copy_modified (
727
724
instance_arg .arg_types [len (prefix .arg_types ) :],
728
725
instance_arg .arg_kinds [len (prefix .arg_kinds ) :],
@@ -793,7 +790,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
793
790
prefix = template_arg .prefix
794
791
if isinstance (mapped_arg , Parameters ):
795
792
# No such thing as variance for ParamSpecs, consider them invariant
796
- # TODO: constraints between prefixes
793
+ # TODO: constraints between prefixes using
794
+ # infer_callable_arguments_constraints()
797
795
suffix = mapped_arg .copy_modified (
798
796
mapped_arg .arg_types [len (prefix .arg_types ) :],
799
797
mapped_arg .arg_kinds [len (prefix .arg_kinds ) :],
@@ -962,24 +960,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
962
960
unpack_constraints = build_constraints_for_simple_unpack (
963
961
template_types , actual_types , neg_op (self .direction )
964
962
)
965
- template_args = []
966
- cactual_args = []
967
963
res .extend (unpack_constraints )
968
964
else :
969
- template_args = template .arg_types
970
- cactual_args = cactual .arg_types
971
- # TODO: use some more principled "formal to actual" logic
972
- # instead of this lock-step loop over argument types. This identical
973
- # logic should be used in 5 places: in Parameters vs Parameters
974
- # inference, in Instance vs Instance inference for prefixes (two
975
- # branches), and in Callable vs Callable inference (two branches).
976
- for t , a in zip (template_args , cactual_args ):
977
- # This avoids bogus constraints like T <: P.args
978
- if isinstance (a , (ParamSpecType , UnpackType )):
979
- # TODO: can we infer something useful for *T vs P?
980
- continue
981
965
# Negate direction due to function argument type contravariance.
982
- res .extend (infer_constraints (t , a , neg_op (self .direction )))
966
+ res .extend (
967
+ infer_callable_arguments_constraints (template , cactual , self .direction )
968
+ )
983
969
else :
984
970
prefix = param_spec .prefix
985
971
prefix_len = len (prefix .arg_types )
@@ -1028,11 +1014,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
1028
1014
arg_kinds = cactual .arg_kinds [:prefix_len ],
1029
1015
arg_names = cactual .arg_names [:prefix_len ],
1030
1016
)
1031
-
1032
- for t , a in zip (prefix .arg_types , cactual_prefix .arg_types ):
1033
- if isinstance (a , ParamSpecType ):
1034
- continue
1035
- res .extend (infer_constraints (t , a , neg_op (self .direction )))
1017
+ res .extend (
1018
+ infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1019
+ )
1036
1020
1037
1021
template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
1038
1022
if template .type_guard is not None :
@@ -1435,3 +1419,89 @@ def build_constraints_for_unpack(
1435
1419
for template_arg , item in zip (template_unpack .items , mapped_middle ):
1436
1420
res .extend (infer_constraints (template_arg , item , direction ))
1437
1421
return res , mapped_prefix + mapped_suffix , template_prefix + template_suffix
1422
+
1423
+
1424
+ def infer_directed_arg_constraints (left : Type , right : Type , direction : int ) -> list [Constraint ]:
1425
+ """Infer constraints between two arguments using direction between original callables."""
1426
+ if isinstance (left , (ParamSpecType , UnpackType )) or isinstance (
1427
+ right , (ParamSpecType , UnpackType )
1428
+ ):
1429
+ # This avoids bogus constraints like T <: P.args
1430
+ # TODO: can we infer something useful for *T vs P?
1431
+ return []
1432
+ if direction == SUBTYPE_OF :
1433
+ # We invert direction to account for argument contravariance.
1434
+ return infer_constraints (left , right , neg_op (direction ))
1435
+ else :
1436
+ return infer_constraints (right , left , neg_op (direction ))
1437
+
1438
+
1439
+ def infer_callable_arguments_constraints (
1440
+ template : CallableType | Parameters , actual : CallableType | Parameters , direction : int
1441
+ ) -> list [Constraint ]:
1442
+ """Infer constraints between argument types of two callables.
1443
+
1444
+ This function essentially extracts four steps from are_parameters_compatible() in
1445
+ subtypes.py that involve subtype checks between argument types. We keep the argument
1446
+ matching logic, but ignore various strictness flags present there, and checks that
1447
+ do not involve subtyping. Then in place of every subtype check we put an infer_constraints()
1448
+ call for the same types.
1449
+ """
1450
+ res = []
1451
+ if direction == SUBTYPE_OF :
1452
+ left , right = template , actual
1453
+ else :
1454
+ left , right = actual , template
1455
+ left_star = left .var_arg ()
1456
+ left_star2 = left .kw_arg ()
1457
+ right_star = right .var_arg ()
1458
+ right_star2 = right .kw_arg ()
1459
+
1460
+ # Numbering of steps below matches the one in are_parameters_compatible() for convenience.
1461
+ # Phase 1a: compare star vs star arguments.
1462
+ if left_star is not None and right_star is not None :
1463
+ res .extend (infer_directed_arg_constraints (left_star .typ , right_star .typ , direction ))
1464
+ if left_star2 is not None and right_star2 is not None :
1465
+ res .extend (infer_directed_arg_constraints (left_star2 .typ , right_star2 .typ , direction ))
1466
+
1467
+ # Phase 1b: compare left args with corresponding non-star right arguments.
1468
+ for right_arg in right .formal_arguments ():
1469
+ left_arg = mypy .typeops .callable_corresponding_argument (left , right_arg )
1470
+ if left_arg is None :
1471
+ continue
1472
+ res .extend (infer_directed_arg_constraints (left_arg .typ , right_arg .typ , direction ))
1473
+
1474
+ # Phase 1c: compare left args with right *args.
1475
+ if right_star is not None :
1476
+ right_by_position = right .try_synthesizing_arg_from_vararg (None )
1477
+ assert right_by_position is not None
1478
+ i = right_star .pos
1479
+ assert i is not None
1480
+ while i < len (left .arg_kinds ) and left .arg_kinds [i ].is_positional ():
1481
+ left_by_position = left .argument_by_position (i )
1482
+ assert left_by_position is not None
1483
+ res .extend (
1484
+ infer_directed_arg_constraints (
1485
+ left_by_position .typ , right_by_position .typ , direction
1486
+ )
1487
+ )
1488
+ i += 1
1489
+
1490
+ # Phase 1d: compare left args with right **kwargs.
1491
+ if right_star2 is not None :
1492
+ right_names = {name for name in right .arg_names if name is not None }
1493
+ left_only_names = set ()
1494
+ for name , kind in zip (left .arg_names , left .arg_kinds ):
1495
+ if name is None or kind .is_star () or name in right_names :
1496
+ continue
1497
+ left_only_names .add (name )
1498
+
1499
+ right_by_name = right .try_synthesizing_arg_from_kwarg (None )
1500
+ assert right_by_name is not None
1501
+ for name in left_only_names :
1502
+ left_by_name = left .argument_by_name (name )
1503
+ assert left_by_name is not None
1504
+ res .extend (
1505
+ infer_directed_arg_constraints (left_by_name .typ , right_by_name .typ , direction )
1506
+ )
1507
+ return res
0 commit comments