Skip to content

Commit 761965d

Browse files
authored
Speed up finding function type variables (#16562)
Merge two visitors into a single visitor that is a bit more optimized than the old visitors. This speeds ups tests, in particular -- `mypy/test/testcheck.py` is about 4% faster and `mypy/test/testpythoneval.py` is about 3% faster. Also self-check is about 1% faster, both interpreted and compiled. This adds more code, but the new code is largely boilerplate, so the difficulty of maintenance seems roughly the same.
1 parent f79ae69 commit 761965d

File tree

2 files changed

+176
-89
lines changed

2 files changed

+176
-89
lines changed

mypy/semanal.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,9 @@
224224
from mypy.tvar_scope import TypeVarLikeScope
225225
from mypy.typeanal import (
226226
SELF_TYPE_NAMES,
227+
FindTypeVarVisitor,
227228
TypeAnalyser,
228229
TypeVarLikeList,
229-
TypeVarLikeQuery,
230230
analyze_type_alias,
231231
check_for_explicit_any,
232232
detect_diverging_alias,
@@ -2034,6 +2034,11 @@ def analyze_unbound_tvar_impl(
20342034
assert isinstance(sym.node, TypeVarExpr)
20352035
return t.name, sym.node
20362036

2037+
def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
2038+
visitor = FindTypeVarVisitor(self, self.tvar_scope)
2039+
t.accept(visitor)
2040+
return visitor.type_var_likes
2041+
20372042
def get_all_bases_tvars(
20382043
self, base_type_exprs: list[Expression], removed: list[int]
20392044
) -> TypeVarLikeList:
@@ -2046,7 +2051,7 @@ def get_all_bases_tvars(
20462051
except TypeTranslationError:
20472052
# This error will be caught later.
20482053
continue
2049-
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
2054+
base_tvars = self.find_type_var_likes(base)
20502055
tvars.extend(base_tvars)
20512056
return remove_dups(tvars)
20522057

@@ -2064,7 +2069,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi
20642069
except TypeTranslationError:
20652070
# This error will be caught later.
20662071
continue
2067-
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
2072+
base_tvars = self.find_type_var_likes(base)
20682073
tvars.extend(base_tvars)
20692074
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
20702075
tvar_defs = []
@@ -3490,7 +3495,7 @@ def analyze_alias(
34903495
)
34913496
return None, [], set(), [], False
34923497

3493-
found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope))
3498+
found_type_vars = self.find_type_var_likes(typ)
34943499
tvar_defs: list[TypeVarLikeType] = []
34953500
namespace = self.qualified_name(name)
34963501
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):

mypy/typeanal.py

Lines changed: 167 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,32 +1570,32 @@ def tvar_scope_frame(self) -> Iterator[None]:
15701570
yield
15711571
self.tvar_scope = old_scope
15721572

1573-
def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList:
1574-
return t.accept(
1575-
TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables)
1576-
)
1577-
1578-
def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]:
1579-
"""Return list of unique type variables referred to in a callable."""
1580-
names: list[str] = []
1581-
tvars: list[TypeVarLikeExpr] = []
1573+
def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
1574+
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
1575+
t.accept(visitor)
1576+
return visitor.type_var_likes
1577+
1578+
def infer_type_variables(
1579+
self, type: CallableType
1580+
) -> tuple[list[tuple[str, TypeVarLikeExpr]], bool]:
1581+
"""Infer type variables from a callable.
1582+
1583+
Return tuple with these items:
1584+
- list of unique type variables referred to in a callable
1585+
- whether there is a reference to the Self type
1586+
"""
1587+
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
15821588
for arg in type.arg_types:
1583-
for name, tvar_expr in self.find_type_var_likes(arg):
1584-
if name not in names:
1585-
names.append(name)
1586-
tvars.append(tvar_expr)
1589+
arg.accept(visitor)
1590+
15871591
# When finding type variables in the return type of a function, don't
15881592
# look inside Callable types. Type variables only appearing in
15891593
# functions in the return type belong to those functions, not the
15901594
# function we're currently analyzing.
1591-
for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False):
1592-
if name not in names:
1593-
names.append(name)
1594-
tvars.append(tvar_expr)
1595+
visitor.include_callables = False
1596+
type.ret_type.accept(visitor)
15951597

1596-
if not names:
1597-
return [] # Fast path
1598-
return list(zip(names, tvars))
1598+
return visitor.type_var_likes, visitor.has_self_type
15991599

16001600
def bind_function_type_variables(
16011601
self, fun_type: CallableType, defn: Context
@@ -1615,10 +1615,7 @@ def bind_function_type_variables(
16151615
binding = self.tvar_scope.bind_new(var.name, var_expr)
16161616
defs.append(binding)
16171617
return defs, has_self_type
1618-
typevars = self.infer_type_variables(fun_type)
1619-
has_self_type = find_self_type(
1620-
fun_type, lambda name: self.api.lookup_qualified(name, defn, suppress_errors=True)
1621-
)
1618+
typevars, has_self_type = self.infer_type_variables(fun_type)
16221619
# Do not define a new type variable if already defined in scope.
16231620
typevars = [
16241621
(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)
@@ -2062,67 +2059,6 @@ def flatten_tvars(lists: list[list[T]]) -> list[T]:
20622059
return result
20632060

20642061

2065-
class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]):
2066-
"""Find TypeVar and ParamSpec references in an unbound type."""
2067-
2068-
def __init__(
2069-
self,
2070-
api: SemanticAnalyzerCoreInterface,
2071-
scope: TypeVarLikeScope,
2072-
*,
2073-
include_callables: bool = True,
2074-
) -> None:
2075-
super().__init__(flatten_tvars)
2076-
self.api = api
2077-
self.scope = scope
2078-
self.include_callables = include_callables
2079-
# Only include type variables in type aliases args. This would be anyway
2080-
# that case if we expand (as target variables would be overridden with args)
2081-
# and it may cause infinite recursion on invalid (diverging) recursive aliases.
2082-
self.skip_alias_target = True
2083-
2084-
def _seems_like_callable(self, type: UnboundType) -> bool:
2085-
if not type.args:
2086-
return False
2087-
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))
2088-
2089-
def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList:
2090-
name = t.name
2091-
node = None
2092-
# Special case P.args and P.kwargs for ParamSpecs only.
2093-
if name.endswith("args"):
2094-
if name.endswith(".args") or name.endswith(".kwargs"):
2095-
base = ".".join(name.split(".")[:-1])
2096-
n = self.api.lookup_qualified(base, t)
2097-
if n is not None and isinstance(n.node, ParamSpecExpr):
2098-
node = n
2099-
name = base
2100-
if node is None:
2101-
node = self.api.lookup_qualified(name, t)
2102-
if (
2103-
node
2104-
and isinstance(node.node, TypeVarLikeExpr)
2105-
and self.scope.get_binding(node) is None
2106-
):
2107-
assert isinstance(node.node, TypeVarLikeExpr)
2108-
return [(name, node.node)]
2109-
elif not self.include_callables and self._seems_like_callable(t):
2110-
return []
2111-
elif node and node.fullname in LITERAL_TYPE_NAMES:
2112-
return []
2113-
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
2114-
# Don't query the second argument to Annotated for TypeVars
2115-
return self.query_types([t.args[0]])
2116-
else:
2117-
return super().visit_unbound_type(t)
2118-
2119-
def visit_callable_type(self, t: CallableType) -> TypeVarLikeList:
2120-
if self.include_callables:
2121-
return super().visit_callable_type(t)
2122-
else:
2123-
return []
2124-
2125-
21262062
class DivergingAliasDetector(TrivialSyntheticTypeTranslator):
21272063
"""See docstring of detect_diverging_alias() for details."""
21282064

@@ -2359,3 +2295,149 @@ def unknown_unpack(t: Type) -> bool:
23592295
if isinstance(unpacked, AnyType) and unpacked.type_of_any == TypeOfAny.special_form:
23602296
return True
23612297
return False
2298+
2299+
2300+
class FindTypeVarVisitor(SyntheticTypeVisitor[None]):
2301+
"""Type visitor that looks for type variable types and self types."""
2302+
2303+
def __init__(self, api: SemanticAnalyzerCoreInterface, scope: TypeVarLikeScope) -> None:
2304+
self.api = api
2305+
self.scope = scope
2306+
self.type_var_likes: list[tuple[str, TypeVarLikeExpr]] = []
2307+
self.has_self_type = False
2308+
self.seen_aliases: set[TypeAliasType] | None = None
2309+
self.include_callables = True
2310+
2311+
def _seems_like_callable(self, type: UnboundType) -> bool:
2312+
if not type.args:
2313+
return False
2314+
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))
2315+
2316+
def visit_unbound_type(self, t: UnboundType) -> None:
2317+
name = t.name
2318+
node = None
2319+
2320+
# Special case P.args and P.kwargs for ParamSpecs only.
2321+
if name.endswith("args"):
2322+
if name.endswith(".args") or name.endswith(".kwargs"):
2323+
base = ".".join(name.split(".")[:-1])
2324+
n = self.api.lookup_qualified(base, t)
2325+
if n is not None and isinstance(n.node, ParamSpecExpr):
2326+
node = n
2327+
name = base
2328+
if node is None:
2329+
node = self.api.lookup_qualified(name, t)
2330+
if node and node.fullname in SELF_TYPE_NAMES:
2331+
self.has_self_type = True
2332+
if (
2333+
node
2334+
and isinstance(node.node, TypeVarLikeExpr)
2335+
and self.scope.get_binding(node) is None
2336+
):
2337+
if (name, node.node) not in self.type_var_likes:
2338+
self.type_var_likes.append((name, node.node))
2339+
elif not self.include_callables and self._seems_like_callable(t):
2340+
if find_self_type(
2341+
t, lambda name: self.api.lookup_qualified(name, t, suppress_errors=True)
2342+
):
2343+
self.has_self_type = True
2344+
return
2345+
elif node and node.fullname in LITERAL_TYPE_NAMES:
2346+
return
2347+
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
2348+
# Don't query the second argument to Annotated for TypeVars
2349+
self.process_types([t.args[0]])
2350+
elif t.args:
2351+
self.process_types(t.args)
2352+
2353+
def visit_type_list(self, t: TypeList) -> None:
2354+
self.process_types(t.items)
2355+
2356+
def visit_callable_argument(self, t: CallableArgument) -> None:
2357+
t.typ.accept(self)
2358+
2359+
def visit_any(self, t: AnyType) -> None:
2360+
pass
2361+
2362+
def visit_uninhabited_type(self, t: UninhabitedType) -> None:
2363+
pass
2364+
2365+
def visit_none_type(self, t: NoneType) -> None:
2366+
pass
2367+
2368+
def visit_erased_type(self, t: ErasedType) -> None:
2369+
pass
2370+
2371+
def visit_deleted_type(self, t: DeletedType) -> None:
2372+
pass
2373+
2374+
def visit_type_var(self, t: TypeVarType) -> None:
2375+
self.process_types([t.upper_bound, t.default] + t.values)
2376+
2377+
def visit_param_spec(self, t: ParamSpecType) -> None:
2378+
self.process_types([t.upper_bound, t.default])
2379+
2380+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
2381+
self.process_types([t.upper_bound, t.default])
2382+
2383+
def visit_unpack_type(self, t: UnpackType) -> None:
2384+
self.process_types([t.type])
2385+
2386+
def visit_parameters(self, t: Parameters) -> None:
2387+
self.process_types(t.arg_types)
2388+
2389+
def visit_partial_type(self, t: PartialType) -> None:
2390+
pass
2391+
2392+
def visit_instance(self, t: Instance) -> None:
2393+
self.process_types(t.args)
2394+
2395+
def visit_callable_type(self, t: CallableType) -> None:
2396+
# FIX generics
2397+
self.process_types(t.arg_types)
2398+
t.ret_type.accept(self)
2399+
2400+
def visit_tuple_type(self, t: TupleType) -> None:
2401+
self.process_types(t.items)
2402+
2403+
def visit_typeddict_type(self, t: TypedDictType) -> None:
2404+
self.process_types(list(t.items.values()))
2405+
2406+
def visit_raw_expression_type(self, t: RawExpressionType) -> None:
2407+
pass
2408+
2409+
def visit_literal_type(self, t: LiteralType) -> None:
2410+
pass
2411+
2412+
def visit_union_type(self, t: UnionType) -> None:
2413+
self.process_types(t.items)
2414+
2415+
def visit_overloaded(self, t: Overloaded) -> None:
2416+
self.process_types(t.items) # type: ignore[arg-type]
2417+
2418+
def visit_type_type(self, t: TypeType) -> None:
2419+
t.item.accept(self)
2420+
2421+
def visit_ellipsis_type(self, t: EllipsisType) -> None:
2422+
pass
2423+
2424+
def visit_placeholder_type(self, t: PlaceholderType) -> None:
2425+
return self.process_types(t.args)
2426+
2427+
def visit_type_alias_type(self, t: TypeAliasType) -> None:
2428+
# Skip type aliases in already visited types to avoid infinite recursion.
2429+
if self.seen_aliases is None:
2430+
self.seen_aliases = set()
2431+
elif t in self.seen_aliases:
2432+
return
2433+
self.seen_aliases.add(t)
2434+
self.process_types(t.args)
2435+
2436+
def process_types(self, types: list[Type] | tuple[Type, ...]) -> None:
2437+
# Redundant type check helps mypyc.
2438+
if isinstance(types, list):
2439+
for t in types:
2440+
t.accept(self)
2441+
else:
2442+
for t in types:
2443+
t.accept(self)

0 commit comments

Comments
 (0)