Skip to content

Commit 6f9723c

Browse files
authored
Fine-grained: Support NewType and reset subtype caches (#4656)
NewType work highlighted an issue with subtype caches with stale information leaking, and this fixes that issue as well. We reset the subtype cache in two places: * When calculating the MRO; we reset caches of all base classes as well. * When merging a new version of a TypeInfo, which may have a different MRO; we reset all caches of base classes in the old MRO, as they might no longer be supertypes.
1 parent 21dcd80 commit 6f9723c

File tree

8 files changed

+156
-21
lines changed

8 files changed

+156
-21
lines changed

mypy/checker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,12 +874,12 @@ def is_trivial_body(self, block: Block) -> bool:
874874
body = block.body
875875

876876
# Skip a docstring
877-
if (isinstance(body[0], ExpressionStmt) and
877+
if (body and isinstance(body[0], ExpressionStmt) and
878878
isinstance(body[0].expr, (StrExpr, UnicodeExpr))):
879879
body = block.body[1:]
880880

881881
if len(body) == 0:
882-
# There's only a docstring.
882+
# There's only a docstring (or no body at all).
883883
return True
884884
elif len(body) > 1:
885885
return False

mypy/nodes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,11 @@ def is_cached_subtype_check(self, left: 'mypy.types.Instance',
20802080
return (left, right) in self._cache
20812081
return (left, right) in self._cache_proper
20822082

2083+
def reset_subtype_cache(self) -> None:
2084+
for item in self.mro:
2085+
item._cache = set()
2086+
item._cache_proper = set()
2087+
20832088
def __getitem__(self, name: str) -> 'SymbolTableNode':
20842089
n = self.get(name)
20852090
if n:
@@ -2116,6 +2121,7 @@ def calculate_mro(self) -> None:
21162121
self.is_enum = self._calculate_is_enum()
21172122
# The property of falling back to Any is inherited.
21182123
self.fallback_to_any = any(baseinfo.fallback_to_any for baseinfo in self.mro)
2124+
self.reset_subtype_cache()
21192125

21202126
def calculate_metaclass_type(self) -> 'Optional[mypy.types.Instance]':
21212127
declared = self.declared_metaclass

mypy/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,7 @@ def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance)
21822182
arg_types=[Instance(info, []), old_type],
21832183
arg_kinds=[arg.kind for arg in args],
21842184
arg_names=['self', 'item'],
2185-
ret_type=old_type,
2185+
ret_type=NoneTyp(),
21862186
fallback=self.named_type('__builtins__.function'),
21872187
name=name)
21882188
init_func = FuncDef('__init__', args, Block([]), typ=signature)

mypy/server/astmerge.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> None:
166166

167167
def visit_class_def(self, node: ClassDef) -> None:
168168
# TODO additional things?
169+
node.info = self.fixup_and_reset_typeinfo(node.info)
169170
node.defs.body = self.replace_statements(node.defs.body)
170-
node.info = self.fixup(node.info)
171171
info = node.info
172172
for tv in node.type_vars:
173173
self.process_type_var_def(tv)
@@ -214,7 +214,7 @@ def visit_ref_expr(self, node: RefExpr) -> None:
214214

215215
def visit_namedtuple_expr(self, node: NamedTupleExpr) -> None:
216216
super().visit_namedtuple_expr(node)
217-
node.info = self.fixup(node.info)
217+
node.info = self.fixup_and_reset_typeinfo(node.info)
218218
self.process_synthetic_type_info(node.info)
219219

220220
def visit_super_expr(self, node: SuperExpr) -> None:
@@ -229,7 +229,7 @@ def visit_call_expr(self, node: CallExpr) -> None:
229229

230230
def visit_newtype_expr(self, node: NewTypeExpr) -> None:
231231
if node.info:
232-
node.info = self.fixup(node.info)
232+
node.info = self.fixup_and_reset_typeinfo(node.info)
233233
self.process_synthetic_type_info(node.info)
234234
self.fixup_type(node.old_type)
235235
super().visit_newtype_expr(node)
@@ -240,11 +240,11 @@ def visit_lambda_expr(self, node: LambdaExpr) -> None:
240240

241241
def visit_typeddict_expr(self, node: TypedDictExpr) -> None:
242242
super().visit_typeddict_expr(node)
243-
node.info = self.fixup(node.info)
243+
node.info = self.fixup_and_reset_typeinfo(node.info)
244244
self.process_synthetic_type_info(node.info)
245245

246246
def visit_enum_call_expr(self, node: EnumCallExpr) -> None:
247-
node.info = self.fixup(node.info)
247+
node.info = self.fixup_and_reset_typeinfo(node.info)
248248
self.process_synthetic_type_info(node.info)
249249
super().visit_enum_call_expr(node)
250250

@@ -269,6 +269,19 @@ def fixup(self, node: SN) -> SN:
269269
return cast(SN, new)
270270
return node
271271

272+
def fixup_and_reset_typeinfo(self, node: TypeInfo) -> TypeInfo:
273+
"""Fix-up type info and reset subtype caches.
274+
275+
This needs to be called at least once per each merged TypeInfo, as otherwise we
276+
may leak stale caches.
277+
"""
278+
if node in self.replacements:
279+
# The subclass relationships may change, so reset all caches relevant to the
280+
# old MRO.
281+
new = cast(TypeInfo, self.replacements[node])
282+
new.reset_subtype_cache()
283+
return self.fixup(node)
284+
272285
def fixup_type(self, typ: Optional[Type]) -> None:
273286
if typ is not None:
274287
typ.accept(TypeReplaceVisitor(self.replacements))

mypy/server/deps.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
9292
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9393
TupleExpr, ListExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
9494
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
95-
LDEF, MDEF, GDEF, FuncItem, TypeAliasExpr,
95+
LDEF, MDEF, GDEF, FuncItem, TypeAliasExpr, NewTypeExpr,
9696
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
9797
)
9898
from mypy.traverser import TraverserVisitor
@@ -211,18 +211,27 @@ def visit_class_def(self, o: ClassDef) -> None:
211211
# Add dependencies to type variables of a generic class.
212212
for tv in o.type_vars:
213213
self.add_dependency(make_trigger(tv.fullname), target)
214-
# Add dependencies to base types.
215-
for base in o.info.bases:
214+
self.process_type_info(o.info)
215+
super().visit_class_def(o)
216+
self.is_class = old_is_class
217+
self.scope.leave()
218+
219+
def visit_newtype_expr(self, o: NewTypeExpr) -> None:
220+
if o.info:
221+
self.scope.enter_class(o.info)
222+
self.process_type_info(o.info)
223+
self.scope.leave()
224+
225+
def process_type_info(self, info: TypeInfo) -> None:
226+
target = self.scope.current_full_target()
227+
for base in info.bases:
216228
self.add_type_dependencies(base, target=target)
217-
if o.info.tuple_type:
218-
self.add_type_dependencies(o.info.tuple_type, target=make_trigger(target))
219-
if o.info.typeddict_type:
220-
self.add_type_dependencies(o.info.typeddict_type, target=make_trigger(target))
229+
if info.tuple_type:
230+
self.add_type_dependencies(info.tuple_type, target=make_trigger(target))
231+
if info.typeddict_type:
232+
self.add_type_dependencies(info.typeddict_type, target=make_trigger(target))
221233
# TODO: Add dependencies based on remaining TypeInfo attributes.
222-
super().visit_class_def(o)
223234
self.add_type_alias_deps(self.scope.current_target())
224-
self.is_class = old_is_class
225-
info = o.info
226235
for name, node in info.names.items():
227236
if isinstance(node.node, Var):
228237
for base_info in non_trivial_bases(info):
@@ -236,7 +245,6 @@ def visit_class_def(self, o: ClassDef) -> None:
236245
target=make_trigger(info.fullname() + '.' + name))
237246
self.add_dependency(make_trigger(base_info.fullname() + '.__init__'),
238247
target=make_trigger(info.fullname() + '.__init__'))
239-
self.scope.leave()
240248

241249
def visit_import(self, o: Import) -> None:
242250
for id, as_id in o.ids:

test-data/unit/deps-statements.test

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,3 +655,20 @@ class C:
655655
<m.C> -> m.C
656656
<sys.platform> -> m
657657
<sys> -> m
658+
659+
[case testNewType]
660+
from typing import NewType
661+
from m import C
662+
663+
N = NewType('N', C)
664+
665+
def f(n: N) -> None:
666+
pass
667+
[file m.py]
668+
class C:
669+
x: int
670+
[out]
671+
<m.N> -> <m.f>, m, m.f
672+
<m.C.__init__> -> <m.N.__init__>
673+
<m.C.x> -> <m.N.x>
674+
<m.C> -> m, m.N

test-data/unit/diff.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,30 @@ B = Dict[str, S]
681681
__main__.A
682682
__main__.T
683683

684+
[case testNewType]
685+
from typing import NewType
686+
class C: pass
687+
class D: pass
688+
N1 = NewType('N1', C)
689+
N2 = NewType('N2', D)
690+
N3 = NewType('N3', C)
691+
class N4(C): pass
692+
[file next.py]
693+
from typing import NewType
694+
class C: pass
695+
class D(C): pass
696+
N1 = NewType('N1', C)
697+
N2 = NewType('N2', D)
698+
class N3(C): pass
699+
N4 = NewType('N4', C)
700+
[out]
701+
__main__.D
702+
__main__.N2
703+
__main__.N3
704+
__main__.N3.__init__
705+
__main__.N4
706+
__main__.N4.__init__
707+
684708
[case testChangeGenericBaseClassOnly]
685709
from typing import List
686710
class C(List[int]): pass

test-data/unit/fine-grained.test

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,7 +1528,8 @@ import a
15281528
[file a.py]
15291529
from typing import Dict, NewType
15301530

1531-
N = NewType('N', int)
1531+
class A: pass
1532+
N = NewType('N', A)
15321533

15331534
a: Dict[N, int]
15341535

@@ -1538,7 +1539,8 @@ def f(self, x: N) -> None:
15381539
[file a.py.2]
15391540
from typing import Dict, NewType # dummy change
15401541

1541-
N = NewType('N', int)
1542+
class A: pass
1543+
N = NewType('N', A)
15421544

15431545
a: Dict[N, int]
15441546

@@ -2498,6 +2500,71 @@ else:
24982500
[out]
24992501
==
25002502

2503+
[case testNewTypeDependencies1]
2504+
from a import N
2505+
2506+
def f(x: N) -> None:
2507+
x.y = 1
2508+
[file a.py]
2509+
from typing import NewType
2510+
from b import C
2511+
2512+
N = NewType('N', C)
2513+
[file b.py]
2514+
class C:
2515+
y: int
2516+
[file b.py.2]
2517+
class C:
2518+
y: str
2519+
[out]
2520+
==
2521+
main:4: error: Incompatible types in assignment (expression has type "int", variable has type "str")
2522+
2523+
[case testNewTypeDependencies2]
2524+
from a import N
2525+
from b import C, D
2526+
2527+
def f(x: C) -> None: pass
2528+
2529+
def g(x: N) -> None:
2530+
f(x)
2531+
[file a.py]
2532+
from typing import NewType
2533+
from b import D
2534+
2535+
N = NewType('N', D)
2536+
[file b.py]
2537+
class C: pass
2538+
class D(C): pass
2539+
[file b.py.2]
2540+
class C: pass
2541+
class D: pass
2542+
[out]
2543+
==
2544+
main:7: error: Argument 1 to "f" has incompatible type "N"; expected "C"
2545+
2546+
[case testNewTypeDependencies3]
2547+
from a import N
2548+
2549+
def f(x: N) -> None:
2550+
x.y
2551+
[file a.py]
2552+
from typing import NewType
2553+
from b import C
2554+
N = NewType('N', C)
2555+
[file a.py.2]
2556+
from typing import NewType
2557+
from b import D
2558+
N = NewType('N', D)
2559+
[file b.py]
2560+
class C:
2561+
y: int
2562+
class D:
2563+
pass
2564+
[out]
2565+
==
2566+
main:4: error: "N" has no attribute "y"
2567+
25012568
[case testNamedTupleWithinFunction]
25022569
from typing import NamedTuple
25032570
import b

0 commit comments

Comments
 (0)