@@ -108,6 +108,7 @@ def infer_constraints_for_callable(
108
108
callee : CallableType ,
109
109
arg_types : Sequence [Type | None ],
110
110
arg_kinds : list [ArgKind ],
111
+ arg_names : Sequence [str | None ] | None ,
111
112
formal_to_actual : list [list [int ]],
112
113
context : ArgumentInferContext ,
113
114
) -> list [Constraint ]:
@@ -118,6 +119,20 @@ def infer_constraints_for_callable(
118
119
constraints : list [Constraint ] = []
119
120
mapper = ArgTypeExpander (context )
120
121
122
+ param_spec = callee .param_spec ()
123
+ param_spec_arg_types = []
124
+ param_spec_arg_names = []
125
+ param_spec_arg_kinds = []
126
+
127
+ incomplete_star_mapping = False
128
+ for i , actuals in enumerate (formal_to_actual ):
129
+ for actual in actuals :
130
+ if actual is None and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 ):
131
+ # We can't use arguments to infer ParamSpec constraint, if only some
132
+ # are present in the current inference pass.
133
+ incomplete_star_mapping = True
134
+ break
135
+
121
136
for i , actuals in enumerate (formal_to_actual ):
122
137
if isinstance (callee .arg_types [i ], UnpackType ):
123
138
unpack_type = callee .arg_types [i ]
@@ -194,11 +209,47 @@ def infer_constraints_for_callable(
194
209
actual_type = mapper .expand_actual_type (
195
210
actual_arg_type , arg_kinds [actual ], callee .arg_names [i ], callee .arg_kinds [i ]
196
211
)
197
- # TODO: if callee has ParamSpec, we need to collect all actuals that map to star
198
- # args and create single constraint between P and resulting Parameters instead.
199
- c = infer_constraints (callee .arg_types [i ], actual_type , SUPERTYPE_OF )
200
- constraints .extend (c )
201
-
212
+ if (
213
+ param_spec
214
+ and callee .arg_kinds [i ] in (ARG_STAR , ARG_STAR2 )
215
+ and not incomplete_star_mapping
216
+ ):
217
+ # If actual arguments are mapped to ParamSpec type, we can't infer individual
218
+ # constraints, instead store them and infer single constraint at the end.
219
+ # It is impossible to map actual kind to formal kind, so use some heuristic.
220
+ # This inference is used as a fallback, so relying on heuristic should be OK.
221
+ param_spec_arg_types .append (
222
+ mapper .expand_actual_type (
223
+ actual_arg_type , arg_kinds [actual ], None , arg_kinds [actual ]
224
+ )
225
+ )
226
+ actual_kind = arg_kinds [actual ]
227
+ param_spec_arg_kinds .append (
228
+ ARG_POS if actual_kind not in (ARG_STAR , ARG_STAR2 ) else actual_kind
229
+ )
230
+ param_spec_arg_names .append (arg_names [actual ] if arg_names else None )
231
+ else :
232
+ c = infer_constraints (callee .arg_types [i ], actual_type , SUPERTYPE_OF )
233
+ constraints .extend (c )
234
+ if (
235
+ param_spec
236
+ and not any (c .type_var == param_spec .id for c in constraints )
237
+ and not incomplete_star_mapping
238
+ ):
239
+ # Use ParamSpec constraint from arguments only if there are no other constraints,
240
+ # since as explained above it is quite ad-hoc.
241
+ constraints .append (
242
+ Constraint (
243
+ param_spec ,
244
+ SUPERTYPE_OF ,
245
+ Parameters (
246
+ arg_types = param_spec_arg_types ,
247
+ arg_kinds = param_spec_arg_kinds ,
248
+ arg_names = param_spec_arg_names ,
249
+ imprecise_arg_kinds = True ,
250
+ ),
251
+ )
252
+ )
202
253
return constraints
203
254
204
255
@@ -949,6 +1000,14 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
949
1000
res : list [Constraint ] = []
950
1001
cactual = self .actual .with_unpacked_kwargs ()
951
1002
param_spec = template .param_spec ()
1003
+
1004
+ template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
1005
+ if template .type_guard is not None :
1006
+ template_ret_type = template .type_guard
1007
+ if cactual .type_guard is not None :
1008
+ cactual_ret_type = cactual .type_guard
1009
+ res .extend (infer_constraints (template_ret_type , cactual_ret_type , self .direction ))
1010
+
952
1011
if param_spec is None :
953
1012
# TODO: Erase template variables if it is generic?
954
1013
if (
@@ -1008,51 +1067,50 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
1008
1067
)
1009
1068
extra_tvars = True
1010
1069
1070
+ # Compare prefixes as well
1071
+ cactual_prefix = cactual .copy_modified (
1072
+ arg_types = cactual .arg_types [:prefix_len ],
1073
+ arg_kinds = cactual .arg_kinds [:prefix_len ],
1074
+ arg_names = cactual .arg_names [:prefix_len ],
1075
+ )
1076
+ res .extend (
1077
+ infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1078
+ )
1079
+
1080
+ param_spec_target : Type | None = None
1081
+ skip_imprecise = (
1082
+ any (c .type_var == param_spec .id for c in res ) and cactual .imprecise_arg_kinds
1083
+ )
1011
1084
if not cactual_ps :
1012
1085
max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
1013
1086
prefix_len = min (prefix_len , max_prefix_len )
1014
- res .append (
1015
- Constraint (
1016
- param_spec ,
1017
- neg_op (self .direction ),
1018
- Parameters (
1019
- arg_types = cactual .arg_types [prefix_len :],
1020
- arg_kinds = cactual .arg_kinds [prefix_len :],
1021
- arg_names = cactual .arg_names [prefix_len :],
1022
- variables = cactual .variables
1023
- if not type_state .infer_polymorphic
1024
- else [],
1025
- ),
1087
+ # This logic matches top-level callable constraint exception, if we managed
1088
+ # to get other constraints for ParamSpec, don't infer one with imprecise kinds
1089
+ if not skip_imprecise :
1090
+ param_spec_target = Parameters (
1091
+ arg_types = cactual .arg_types [prefix_len :],
1092
+ arg_kinds = cactual .arg_kinds [prefix_len :],
1093
+ arg_names = cactual .arg_names [prefix_len :],
1094
+ variables = cactual .variables
1095
+ if not type_state .infer_polymorphic
1096
+ else [],
1097
+ imprecise_arg_kinds = cactual .imprecise_arg_kinds ,
1026
1098
)
1027
- )
1028
1099
else :
1029
- if len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types ):
1030
- cactual_ps = cactual_ps .copy_modified (
1100
+ if (
1101
+ len (param_spec .prefix .arg_types ) <= len (cactual_ps .prefix .arg_types )
1102
+ and not skip_imprecise
1103
+ ):
1104
+ param_spec_target = cactual_ps .copy_modified (
1031
1105
prefix = Parameters (
1032
1106
arg_types = cactual_ps .prefix .arg_types [prefix_len :],
1033
1107
arg_kinds = cactual_ps .prefix .arg_kinds [prefix_len :],
1034
1108
arg_names = cactual_ps .prefix .arg_names [prefix_len :],
1109
+ imprecise_arg_kinds = cactual_ps .prefix .imprecise_arg_kinds ,
1035
1110
)
1036
1111
)
1037
- res .append (Constraint (param_spec , neg_op (self .direction ), cactual_ps ))
1038
-
1039
- # Compare prefixes as well
1040
- cactual_prefix = cactual .copy_modified (
1041
- arg_types = cactual .arg_types [:prefix_len ],
1042
- arg_kinds = cactual .arg_kinds [:prefix_len ],
1043
- arg_names = cactual .arg_names [:prefix_len ],
1044
- )
1045
- res .extend (
1046
- infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1047
- )
1048
-
1049
- template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
1050
- if template .type_guard is not None :
1051
- template_ret_type = template .type_guard
1052
- if cactual .type_guard is not None :
1053
- cactual_ret_type = cactual .type_guard
1054
-
1055
- res .extend (infer_constraints (template_ret_type , cactual_ret_type , self .direction ))
1112
+ if param_spec_target is not None :
1113
+ res .append (Constraint (param_spec , neg_op (self .direction ), param_spec_target ))
1056
1114
if extra_tvars :
1057
1115
for c in res :
1058
1116
c .extra_tvars += cactual .variables
0 commit comments