@@ -1570,32 +1570,32 @@ def tvar_scope_frame(self) -> Iterator[None]:
1570
1570
yield
1571
1571
self .tvar_scope = old_scope
1572
1572
1573
- def find_type_var_likes (self , t : Type , include_callables : bool = True ) -> TypeVarLikeList :
1574
- return t .accept (
1575
- TypeVarLikeQuery (self .api , self .tvar_scope , include_callables = include_callables )
1576
- )
1577
-
1578
- def infer_type_variables (self , type : CallableType ) -> list [tuple [str , TypeVarLikeExpr ]]:
1579
- """Return list of unique type variables referred to in a callable."""
1580
- names : list [str ] = []
1581
- tvars : list [TypeVarLikeExpr ] = []
1573
+ def find_type_var_likes (self , t : Type ) -> TypeVarLikeList :
1574
+ visitor = FindTypeVarVisitor (self .api , self .tvar_scope )
1575
+ t .accept (visitor )
1576
+ return visitor .type_var_likes
1577
+
1578
+ def infer_type_variables (
1579
+ self , type : CallableType
1580
+ ) -> tuple [list [tuple [str , TypeVarLikeExpr ]], bool ]:
1581
+ """Infer type variables from a callable.
1582
+
1583
+ Return tuple with these items:
1584
+ - list of unique type variables referred to in a callable
1585
+ - whether there is a reference to the Self type
1586
+ """
1587
+ visitor = FindTypeVarVisitor (self .api , self .tvar_scope )
1582
1588
for arg in type .arg_types :
1583
- for name , tvar_expr in self .find_type_var_likes (arg ):
1584
- if name not in names :
1585
- names .append (name )
1586
- tvars .append (tvar_expr )
1589
+ arg .accept (visitor )
1590
+
1587
1591
# When finding type variables in the return type of a function, don't
1588
1592
# look inside Callable types. Type variables only appearing in
1589
1593
# functions in the return type belong to those functions, not the
1590
1594
# function we're currently analyzing.
1591
- for name , tvar_expr in self .find_type_var_likes (type .ret_type , include_callables = False ):
1592
- if name not in names :
1593
- names .append (name )
1594
- tvars .append (tvar_expr )
1595
+ visitor .include_callables = False
1596
+ type .ret_type .accept (visitor )
1595
1597
1596
- if not names :
1597
- return [] # Fast path
1598
- return list (zip (names , tvars ))
1598
+ return visitor .type_var_likes , visitor .has_self_type
1599
1599
1600
1600
def bind_function_type_variables (
1601
1601
self , fun_type : CallableType , defn : Context
@@ -1615,10 +1615,7 @@ def bind_function_type_variables(
1615
1615
binding = self .tvar_scope .bind_new (var .name , var_expr )
1616
1616
defs .append (binding )
1617
1617
return defs , has_self_type
1618
- typevars = self .infer_type_variables (fun_type )
1619
- has_self_type = find_self_type (
1620
- fun_type , lambda name : self .api .lookup_qualified (name , defn , suppress_errors = True )
1621
- )
1618
+ typevars , has_self_type = self .infer_type_variables (fun_type )
1622
1619
# Do not define a new type variable if already defined in scope.
1623
1620
typevars = [
1624
1621
(name , tvar ) for name , tvar in typevars if not self .is_defined_type_var (name , defn )
@@ -2062,67 +2059,6 @@ def flatten_tvars(lists: list[list[T]]) -> list[T]:
2062
2059
return result
2063
2060
2064
2061
2065
- class TypeVarLikeQuery (TypeQuery [TypeVarLikeList ]):
2066
- """Find TypeVar and ParamSpec references in an unbound type."""
2067
-
2068
- def __init__ (
2069
- self ,
2070
- api : SemanticAnalyzerCoreInterface ,
2071
- scope : TypeVarLikeScope ,
2072
- * ,
2073
- include_callables : bool = True ,
2074
- ) -> None :
2075
- super ().__init__ (flatten_tvars )
2076
- self .api = api
2077
- self .scope = scope
2078
- self .include_callables = include_callables
2079
- # Only include type variables in type aliases args. This would be anyway
2080
- # that case if we expand (as target variables would be overridden with args)
2081
- # and it may cause infinite recursion on invalid (diverging) recursive aliases.
2082
- self .skip_alias_target = True
2083
-
2084
- def _seems_like_callable (self , type : UnboundType ) -> bool :
2085
- if not type .args :
2086
- return False
2087
- return isinstance (type .args [0 ], (EllipsisType , TypeList , ParamSpecType ))
2088
-
2089
- def visit_unbound_type (self , t : UnboundType ) -> TypeVarLikeList :
2090
- name = t .name
2091
- node = None
2092
- # Special case P.args and P.kwargs for ParamSpecs only.
2093
- if name .endswith ("args" ):
2094
- if name .endswith (".args" ) or name .endswith (".kwargs" ):
2095
- base = "." .join (name .split ("." )[:- 1 ])
2096
- n = self .api .lookup_qualified (base , t )
2097
- if n is not None and isinstance (n .node , ParamSpecExpr ):
2098
- node = n
2099
- name = base
2100
- if node is None :
2101
- node = self .api .lookup_qualified (name , t )
2102
- if (
2103
- node
2104
- and isinstance (node .node , TypeVarLikeExpr )
2105
- and self .scope .get_binding (node ) is None
2106
- ):
2107
- assert isinstance (node .node , TypeVarLikeExpr )
2108
- return [(name , node .node )]
2109
- elif not self .include_callables and self ._seems_like_callable (t ):
2110
- return []
2111
- elif node and node .fullname in LITERAL_TYPE_NAMES :
2112
- return []
2113
- elif node and node .fullname in ANNOTATED_TYPE_NAMES and t .args :
2114
- # Don't query the second argument to Annotated for TypeVars
2115
- return self .query_types ([t .args [0 ]])
2116
- else :
2117
- return super ().visit_unbound_type (t )
2118
-
2119
- def visit_callable_type (self , t : CallableType ) -> TypeVarLikeList :
2120
- if self .include_callables :
2121
- return super ().visit_callable_type (t )
2122
- else :
2123
- return []
2124
-
2125
-
2126
2062
class DivergingAliasDetector (TrivialSyntheticTypeTranslator ):
2127
2063
"""See docstring of detect_diverging_alias() for details."""
2128
2064
@@ -2359,3 +2295,149 @@ def unknown_unpack(t: Type) -> bool:
2359
2295
if isinstance (unpacked , AnyType ) and unpacked .type_of_any == TypeOfAny .special_form :
2360
2296
return True
2361
2297
return False
2298
+
2299
+
2300
+ class FindTypeVarVisitor (SyntheticTypeVisitor [None ]):
2301
+ """Type visitor that looks for type variable types and self types."""
2302
+
2303
+ def __init__ (self , api : SemanticAnalyzerCoreInterface , scope : TypeVarLikeScope ) -> None :
2304
+ self .api = api
2305
+ self .scope = scope
2306
+ self .type_var_likes : list [tuple [str , TypeVarLikeExpr ]] = []
2307
+ self .has_self_type = False
2308
+ self .seen_aliases : set [TypeAliasType ] | None = None
2309
+ self .include_callables = True
2310
+
2311
+ def _seems_like_callable (self , type : UnboundType ) -> bool :
2312
+ if not type .args :
2313
+ return False
2314
+ return isinstance (type .args [0 ], (EllipsisType , TypeList , ParamSpecType ))
2315
+
2316
+ def visit_unbound_type (self , t : UnboundType ) -> None :
2317
+ name = t .name
2318
+ node = None
2319
+
2320
+ # Special case P.args and P.kwargs for ParamSpecs only.
2321
+ if name .endswith ("args" ):
2322
+ if name .endswith (".args" ) or name .endswith (".kwargs" ):
2323
+ base = "." .join (name .split ("." )[:- 1 ])
2324
+ n = self .api .lookup_qualified (base , t )
2325
+ if n is not None and isinstance (n .node , ParamSpecExpr ):
2326
+ node = n
2327
+ name = base
2328
+ if node is None :
2329
+ node = self .api .lookup_qualified (name , t )
2330
+ if node and node .fullname in SELF_TYPE_NAMES :
2331
+ self .has_self_type = True
2332
+ if (
2333
+ node
2334
+ and isinstance (node .node , TypeVarLikeExpr )
2335
+ and self .scope .get_binding (node ) is None
2336
+ ):
2337
+ if (name , node .node ) not in self .type_var_likes :
2338
+ self .type_var_likes .append ((name , node .node ))
2339
+ elif not self .include_callables and self ._seems_like_callable (t ):
2340
+ if find_self_type (
2341
+ t , lambda name : self .api .lookup_qualified (name , t , suppress_errors = True )
2342
+ ):
2343
+ self .has_self_type = True
2344
+ return
2345
+ elif node and node .fullname in LITERAL_TYPE_NAMES :
2346
+ return
2347
+ elif node and node .fullname in ANNOTATED_TYPE_NAMES and t .args :
2348
+ # Don't query the second argument to Annotated for TypeVars
2349
+ self .process_types ([t .args [0 ]])
2350
+ elif t .args :
2351
+ self .process_types (t .args )
2352
+
2353
+ def visit_type_list (self , t : TypeList ) -> None :
2354
+ self .process_types (t .items )
2355
+
2356
+ def visit_callable_argument (self , t : CallableArgument ) -> None :
2357
+ t .typ .accept (self )
2358
+
2359
+ def visit_any (self , t : AnyType ) -> None :
2360
+ pass
2361
+
2362
+ def visit_uninhabited_type (self , t : UninhabitedType ) -> None :
2363
+ pass
2364
+
2365
+ def visit_none_type (self , t : NoneType ) -> None :
2366
+ pass
2367
+
2368
+ def visit_erased_type (self , t : ErasedType ) -> None :
2369
+ pass
2370
+
2371
+ def visit_deleted_type (self , t : DeletedType ) -> None :
2372
+ pass
2373
+
2374
+ def visit_type_var (self , t : TypeVarType ) -> None :
2375
+ self .process_types ([t .upper_bound , t .default ] + t .values )
2376
+
2377
+ def visit_param_spec (self , t : ParamSpecType ) -> None :
2378
+ self .process_types ([t .upper_bound , t .default ])
2379
+
2380
+ def visit_type_var_tuple (self , t : TypeVarTupleType ) -> None :
2381
+ self .process_types ([t .upper_bound , t .default ])
2382
+
2383
+ def visit_unpack_type (self , t : UnpackType ) -> None :
2384
+ self .process_types ([t .type ])
2385
+
2386
+ def visit_parameters (self , t : Parameters ) -> None :
2387
+ self .process_types (t .arg_types )
2388
+
2389
+ def visit_partial_type (self , t : PartialType ) -> None :
2390
+ pass
2391
+
2392
+ def visit_instance (self , t : Instance ) -> None :
2393
+ self .process_types (t .args )
2394
+
2395
+ def visit_callable_type (self , t : CallableType ) -> None :
2396
+ # FIX generics
2397
+ self .process_types (t .arg_types )
2398
+ t .ret_type .accept (self )
2399
+
2400
+ def visit_tuple_type (self , t : TupleType ) -> None :
2401
+ self .process_types (t .items )
2402
+
2403
+ def visit_typeddict_type (self , t : TypedDictType ) -> None :
2404
+ self .process_types (list (t .items .values ()))
2405
+
2406
+ def visit_raw_expression_type (self , t : RawExpressionType ) -> None :
2407
+ pass
2408
+
2409
+ def visit_literal_type (self , t : LiteralType ) -> None :
2410
+ pass
2411
+
2412
+ def visit_union_type (self , t : UnionType ) -> None :
2413
+ self .process_types (t .items )
2414
+
2415
+ def visit_overloaded (self , t : Overloaded ) -> None :
2416
+ self .process_types (t .items ) # type: ignore[arg-type]
2417
+
2418
+ def visit_type_type (self , t : TypeType ) -> None :
2419
+ t .item .accept (self )
2420
+
2421
+ def visit_ellipsis_type (self , t : EllipsisType ) -> None :
2422
+ pass
2423
+
2424
+ def visit_placeholder_type (self , t : PlaceholderType ) -> None :
2425
+ return self .process_types (t .args )
2426
+
2427
+ def visit_type_alias_type (self , t : TypeAliasType ) -> None :
2428
+ # Skip type aliases in already visited types to avoid infinite recursion.
2429
+ if self .seen_aliases is None :
2430
+ self .seen_aliases = set ()
2431
+ elif t in self .seen_aliases :
2432
+ return
2433
+ self .seen_aliases .add (t )
2434
+ self .process_types (t .args )
2435
+
2436
+ def process_types (self , types : list [Type ] | tuple [Type , ...]) -> None :
2437
+ # Redundant type check helps mypyc.
2438
+ if isinstance (types , list ):
2439
+ for t in types :
2440
+ t .accept (self )
2441
+ else :
2442
+ for t in types :
2443
+ t .accept (self )
0 commit comments