Skip to content

Commit b99f43f

Browse files
authored
Fix crash due to checking type variable values too early (#4384)
Move type variable checks which use subtype and type sameness checks to happen at the end of semantic analysis. The implementation also adds the concept of priorities to semantic analysis patch callbacks. Callback calls are sorted by the priority. We resolve forward references and calculate fallbacks before checking type variable values, as otherwise the latter could see incomplete types and crash. Fixes #4200.
1 parent 105b78e commit b99f43f

8 files changed

+130
-78
lines changed

mypy/build.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,14 +1853,14 @@ def parse_file(self) -> None:
18531853

18541854
def semantic_analysis(self) -> None:
18551855
assert self.tree is not None, "Internal error: method must be called on parsed file only"
1856-
patches = [] # type: List[Callable[[], None]]
1856+
patches = [] # type: List[Tuple[int, Callable[[], None]]]
18571857
with self.wrap_context():
18581858
self.manager.semantic_analyzer.visit_file(self.tree, self.xpath, self.options, patches)
18591859
self.patches = patches
18601860

18611861
def semantic_analysis_pass_three(self) -> None:
18621862
assert self.tree is not None, "Internal error: method must be called on parsed file only"
1863-
patches = [] # type: List[Callable[[], None]]
1863+
patches = [] # type: List[Tuple[int, Callable[[], None]]]
18641864
with self.wrap_context():
18651865
self.manager.semantic_analyzer_pass3.visit_file(self.tree, self.xpath,
18661866
self.options, patches)
@@ -1869,7 +1869,8 @@ def semantic_analysis_pass_three(self) -> None:
18691869
self.patches = patches + self.patches
18701870

18711871
def semantic_analysis_apply_patches(self) -> None:
1872-
for patch_func in self.patches:
1872+
patches_by_priority = sorted(self.patches, key=lambda x: x[0])
1873+
for priority, patch_func in patches_by_priority:
18731874
patch_func()
18741875

18751876
def type_check_first_pass(self) -> None:

mypy/semanal.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
8585
from mypy import join
8686
from mypy.util import get_prefix
87+
from mypy.semanal_shared import PRIORITY_FALLBACKS
8788

8889

8990
T = TypeVar('T')
@@ -258,11 +259,12 @@ def __init__(self,
258259
self.recurse_into_functions = True
259260

260261
def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
261-
patches: List[Callable[[], None]]) -> None:
262+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
262263
"""Run semantic analysis phase 2 over a file.
263264
264-
Add callbacks by mutating the patches list argument. They will be called
265-
after all semantic analysis phases but before type checking.
265+
Add (priority, callback) pairs by mutating the 'patches' list argument. They
266+
will be called after all semantic analysis phases but before type checking,
267+
lowest priority values first.
266268
"""
267269
self.recurse_into_functions = True
268270
self.options = options
@@ -2454,7 +2456,7 @@ def patch() -> None:
24542456
# We can't calculate the complete fallback type until after semantic
24552457
# analysis, since otherwise MROs might be incomplete. Postpone a callback
24562458
# function that patches the fallback.
2457-
self.patches.append(patch)
2459+
self.patches.append((PRIORITY_FALLBACKS, patch))
24582460

24592461
def add_field(var: Var, is_initialized_in_class: bool = False,
24602462
is_property: bool = False) -> None:
@@ -2693,7 +2695,7 @@ def patch() -> None:
26932695
# We can't calculate the complete fallback type until after semantic
26942696
# analysis, since otherwise MROs might be incomplete. Postpone a callback
26952697
# function that patches the fallback.
2696-
self.patches.append(patch)
2698+
self.patches.append((PRIORITY_FALLBACKS, patch))
26972699
return info
26982700

26992701
def check_classvar(self, s: AssignmentStmt) -> None:

mypy/semanal_pass3.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from collections import OrderedDict
13-
from typing import Dict, List, Callable, Optional, Union, Set, cast
13+
from typing import Dict, List, Callable, Optional, Union, Set, cast, Tuple
1414

1515
from mypy import messages, experiments
1616
from mypy.nodes import (
@@ -28,6 +28,9 @@
2828
from mypy.traverser import TraverserVisitor
2929
from mypy.typeanal import TypeAnalyserPass3, collect_any_types
3030
from mypy.typevars import has_no_typevars
31+
from mypy.semanal_shared import PRIORITY_FORWARD_REF, PRIORITY_TYPEVAR_VALUES
32+
from mypy.subtypes import is_subtype
33+
from mypy.sametypes import is_same_type
3134
import mypy.semanal
3235

3336

@@ -48,7 +51,7 @@ def __init__(self, modules: Dict[str, MypyFile], errors: Errors,
4851
self.recurse_into_functions = True
4952

5053
def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
51-
patches: List[Callable[[], None]]) -> None:
54+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
5255
self.recurse_into_functions = True
5356
self.errors.set_file(fnam, file_node.fullname())
5457
self.options = options
@@ -349,12 +352,7 @@ def analyze(self, type: Optional[Type], node: Union[Node, SymbolTableNode],
349352
analyzer = self.make_type_analyzer(indicator)
350353
type.accept(analyzer)
351354
self.check_for_omitted_generics(type)
352-
if indicator.get('forward') or indicator.get('synthetic'):
353-
def patch() -> None:
354-
self.perform_transform(node,
355-
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
356-
node, warn)))
357-
self.patches.append(patch)
355+
self.generate_type_patches(node, indicator, warn)
358356

359357
def analyze_types(self, types: List[Type], node: Node) -> None:
360358
# Similar to above but for nodes with multiple types.
@@ -363,12 +361,24 @@ def analyze_types(self, types: List[Type], node: Node) -> None:
363361
analyzer = self.make_type_analyzer(indicator)
364362
type.accept(analyzer)
365363
self.check_for_omitted_generics(type)
364+
self.generate_type_patches(node, indicator, warn=False)
365+
366+
def generate_type_patches(self,
367+
node: Union[Node, SymbolTableNode],
368+
indicator: Dict[str, bool],
369+
warn: bool) -> None:
366370
if indicator.get('forward') or indicator.get('synthetic'):
367371
def patch() -> None:
368372
self.perform_transform(node,
369373
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
370-
node, warn=False)))
371-
self.patches.append(patch)
374+
node, warn)))
375+
self.patches.append((PRIORITY_FORWARD_REF, patch))
376+
if indicator.get('typevar'):
377+
def patch() -> None:
378+
self.perform_transform(node,
379+
lambda tp: tp.accept(TypeVariableChecker(self.fail)))
380+
381+
self.patches.append((PRIORITY_TYPEVAR_VALUES, patch))
372382

373383
def analyze_info(self, info: TypeInfo) -> None:
374384
# Similar to above but for nodes with synthetic TypeInfos (NamedTuple and NewType).
@@ -387,7 +397,8 @@ def make_type_analyzer(self, indicator: Dict[str, bool]) -> TypeAnalyserPass3:
387397
self.sem.plugin,
388398
self.options,
389399
self.is_typeshed_file,
390-
indicator)
400+
indicator,
401+
self.patches)
391402

392403
def check_for_omitted_generics(self, typ: Type) -> None:
393404
if not self.options.disallow_any_generics or self.is_typeshed_file:
@@ -606,3 +617,58 @@ def visit_type_type(self, t: TypeType) -> Type:
606617
if self.check_recursion(t):
607618
return AnyType(TypeOfAny.from_error)
608619
return super().visit_type_type(t)
620+
621+
622+
class TypeVariableChecker(TypeTranslator):
623+
"""Visitor that checks that type variables in generic types have valid values.
624+
625+
Note: This must be run at the end of semantic analysis when MROs are
626+
complete and forward references have been resolved.
627+
628+
This does two things:
629+
630+
- If type variable in C has a value restriction, check that X in C[X] conforms
631+
to the restriction.
632+
- If type variable in C has a non-default upper bound, check that X in C[X]
633+
conforms to the upper bound.
634+
635+
(This doesn't need to be a type translator, but it simplifies the implementation.)
636+
"""
637+
638+
def __init__(self, fail: Callable[[str, Context], None]) -> None:
639+
self.fail = fail
640+
641+
def visit_instance(self, t: Instance) -> Type:
642+
info = t.type
643+
for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars):
644+
if tvar.values:
645+
if isinstance(arg, TypeVarType):
646+
arg_values = arg.values
647+
if not arg_values:
648+
self.fail('Type variable "{}" not valid as type '
649+
'argument value for "{}"'.format(
650+
arg.name, info.name()), t)
651+
continue
652+
else:
653+
arg_values = [arg]
654+
self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t)
655+
if not is_subtype(arg, tvar.upper_bound):
656+
self.fail('Type argument "{}" of "{}" must be '
657+
'a subtype of "{}"'.format(
658+
arg, info.name(), tvar.upper_bound), t)
659+
return t
660+
661+
def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
662+
valids: List[Type], arg_number: int, context: Context) -> None:
663+
for actual in actuals:
664+
if (not isinstance(actual, AnyType) and
665+
not any(is_same_type(actual, value)
666+
for value in valids)):
667+
if len(actuals) > 1 or not isinstance(actual, Instance):
668+
self.fail('Invalid type argument value for "{}"'.format(
669+
type.name()), context)
670+
else:
671+
class_name = '"{}"'.format(type.name())
672+
actual_type_name = '"{}"'.format(actual.type.name())
673+
self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
674+
arg_name, class_name, actual_type_name), context)

mypy/semanal_shared.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Shared definitions used by different parts of semantic analysis."""
2+
3+
# Priorities for ordering of patches within the final "patch" phase of semantic analysis
4+
# (after pass 3):
5+
6+
# Fix forward references (needs to happen first)
7+
PRIORITY_FORWARD_REF = 0
8+
# Fix fallbacks (does joins)
9+
PRIORITY_FALLBACKS = 1
10+
# Checks type var values (does subtype checks)
11+
PRIORITY_TYPEVAR_VALUES = 2

mypy/typeanal.py

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Semantic analysis of types"""
22

33
from collections import OrderedDict
4-
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict
4+
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict, Union
55
from itertools import chain
66

77
from contextlib import contextmanager
@@ -14,19 +14,18 @@
1414
Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType,
1515
CallableType, NoneTyp, DeletedType, TypeList, TypeVarDef, TypeVisitor, SyntheticTypeVisitor,
1616
StarType, PartialType, EllipsisType, UninhabitedType, TypeType, get_typ_args, set_typ_args,
17-
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded
17+
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded,
18+
TypeTranslator
1819
)
1920

2021
from mypy.nodes import (
2122
TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, TypeInfo, Context, SymbolTableNode, Var, Expression,
2223
IndexExpr, RefExpr, nongen_builtins, check_arg_names, check_arg_kinds, ARG_POS, ARG_NAMED,
2324
ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, FuncDef, CallExpr, NameExpr,
24-
Decorator
25+
Decorator, Node
2526
)
2627
from mypy.tvar_scope import TypeVarScope
27-
from mypy.sametypes import is_same_type
2828
from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError
29-
from mypy.subtypes import is_subtype
3029
from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext
3130
from mypy import nodes, messages
3231

@@ -656,7 +655,8 @@ def __init__(self,
656655
plugin: Plugin,
657656
options: Options,
658657
is_typeshed_stub: bool,
659-
indicator: Dict[str, bool]) -> None:
658+
indicator: Dict[str, bool],
659+
patches: List[Tuple[int, Callable[[], None]]]) -> None:
660660
self.lookup_func = lookup_func
661661
self.lookup_fqn_func = lookup_fqn_func
662662
self.fail = fail_func
@@ -665,6 +665,7 @@ def __init__(self,
665665
self.plugin = plugin
666666
self.is_typeshed_stub = is_typeshed_stub
667667
self.indicator = indicator
668+
self.patches = patches
668669

669670
def visit_instance(self, t: Instance) -> None:
670671
info = t.type
@@ -707,64 +708,21 @@ def visit_instance(self, t: Instance) -> None:
707708
t.args = [AnyType(TypeOfAny.from_error) for _ in info.type_vars]
708709
t.invalid = True
709710
elif info.defn.type_vars:
710-
# Check type argument values.
711-
# TODO: Calling is_subtype and is_same_types in semantic analysis is a bad idea
712-
for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars):
713-
if tvar.values:
714-
if isinstance(arg, TypeVarType):
715-
arg_values = arg.values
716-
if not arg_values:
717-
self.fail('Type variable "{}" not valid as type '
718-
'argument value for "{}"'.format(
719-
arg.name, info.name()), t)
720-
continue
721-
else:
722-
arg_values = [arg]
723-
self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t)
724-
# TODO: These hacks will be not necessary when this will be moved to later stage.
725-
arg = self.resolve_type(arg)
726-
bound = self.resolve_type(tvar.upper_bound)
727-
if not is_subtype(arg, bound):
728-
self.fail('Type argument "{}" of "{}" must be '
729-
'a subtype of "{}"'.format(
730-
arg, info.name(), bound), t)
711+
# Check type argument values. This is postponed to the end of semantic analysis
712+
# since we need full MROs and resolved forward references.
713+
for tvar in info.defn.type_vars:
714+
if (tvar.values
715+
or not isinstance(tvar.upper_bound, Instance)
716+
or tvar.upper_bound.type.fullname() != 'builtins.object'):
717+
# Some restrictions on type variable. These can only be checked later
718+
# after we have final MROs and forward references have been resolved.
719+
self.indicator['typevar'] = True
731720
for arg in t.args:
732721
arg.accept(self)
733722
if info.is_newtype:
734723
for base in info.bases:
735724
base.accept(self)
736725

737-
def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
738-
valids: List[Type], arg_number: int, context: Context) -> None:
739-
for actual in actuals:
740-
actual = self.resolve_type(actual)
741-
if (not isinstance(actual, AnyType) and
742-
not any(is_same_type(actual, self.resolve_type(value))
743-
for value in valids)):
744-
if len(actuals) > 1 or not isinstance(actual, Instance):
745-
self.fail('Invalid type argument value for "{}"'.format(
746-
type.name()), context)
747-
else:
748-
class_name = '"{}"'.format(type.name())
749-
actual_type_name = '"{}"'.format(actual.type.name())
750-
self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
751-
arg_name, class_name, actual_type_name), context)
752-
753-
def resolve_type(self, tp: Type) -> Type:
754-
# This helper is only needed while is_subtype and is_same_type are
755-
# called in third pass. This can be removed when TODO in visit_instance is fixed.
756-
if isinstance(tp, ForwardRef):
757-
if tp.resolved is None:
758-
return tp.unbound
759-
tp = tp.resolved
760-
if isinstance(tp, Instance) and tp.type.replaced:
761-
replaced = tp.type.replaced
762-
if replaced.tuple_type:
763-
tp = replaced.tuple_type
764-
if replaced.typeddict_type:
765-
tp = replaced.typeddict_type
766-
return tp
767-
768726
def visit_callable_type(self, t: CallableType) -> None:
769727
t.ret_type.accept(self)
770728
for arg_type in t.arg_types:

test-data/unit/check-newtype.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,10 @@ d: object
360360
if isinstance(d, T): # E: Cannot use isinstance() with a NewType type
361361
reveal_type(d) # E: Revealed type is '__main__.T'
362362
[builtins fixtures/isinstancelist.pyi]
363+
364+
[case testInvalidNewTypeCrash]
365+
from typing import List, NewType, Union
366+
N = NewType('N', XXX) # E: Argument 2 to NewType(...) must be subclassable (got "Any") \
367+
# E: Name 'XXX' is not defined
368+
x: List[Union[N, int]] # E: Invalid type "__main__.N"
369+
[builtins fixtures/list.pyi]

test-data/unit/check-typeddict.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,7 @@ T = TypeVar('T', bound='M')
13331333
class G(Generic[T]):
13341334
x: T
13351335

1336-
yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict({'x': builtins.int}, fallback=typing.Mapping[builtins.str, builtins.object])"
1336+
yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict('__main__.M', {'x': builtins.int})"
13371337
yg: G[M]
13381338
z: int = G[M]().x['x']
13391339

test-data/unit/check-unions.test

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,3 +940,10 @@ x: Union[ExtremelyLongTypeNameWhichIsGenericSoWeCanUseItMultipleTimes[int],
940940
def takes_int(arg: int) -> None: pass
941941

942942
takes_int(x) # E: Argument 1 to "takes_int" has incompatible type <union: 6 items>; expected "int"
943+
944+
[case testRecursiveForwardReferenceInUnion]
945+
from typing import List, Union
946+
MYTYPE = List[Union[str, "MYTYPE"]]
947+
[builtins fixtures/list.pyi]
948+
[out]
949+
main:2: error: Recursive types not fully supported yet, nested types replaced with "Any"

0 commit comments

Comments
 (0)