From 7518b43345d65e7a366740669f009403301cb4ab Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 31 Jul 2022 18:37:40 +0100 Subject: [PATCH 1/7] Enable recursive type aliases behind a flag --- mypy/checker.py | 11 +++- mypy/checkexpr.py | 14 ++++ mypy/expandtype.py | 3 +- mypy/infer.py | 3 +- mypy/main.py | 5 ++ mypy/messages.py | 13 +++- mypy/options.py | 2 + mypy/sametypes.py | 4 ++ mypy/semanal.py | 76 +++++++++++++++++++--- mypy/solve.py | 22 ++++++- mypy/subtypes.py | 4 +- mypy/typeanal.py | 2 +- mypy/typeops.py | 12 +--- mypy/types.py | 47 +++++++++++-- test-data/unit/check-newsemanal.test | 5 ++ test-data/unit/check-recursive-types.test | 59 +++++++++++++++++ test-data/unit/fixtures/isinstancelist.pyi | 2 + 17 files changed, 248 insertions(+), 36 deletions(-) create mode 100644 test-data/unit/check-recursive-types.test diff --git a/mypy/checker.py b/mypy/checker.py index 00e104d8bcf3..f0052df1e648 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3662,11 +3662,13 @@ def check_simple_assignment( # '...' is always a valid initializer in a stub. return AnyType(TypeOfAny.special_form) else: + orig_lvalue = lvalue_type lvalue_type = get_proper_type(lvalue_type) always_allow_any = lvalue_type is not None and not isinstance(lvalue_type, AnyType) rvalue_type = self.expr_checker.accept( rvalue, lvalue_type, always_allow_any=always_allow_any ) + orig_rvalue = rvalue_type rvalue_type = get_proper_type(rvalue_type) if isinstance(rvalue_type, DeletedType): self.msg.deleted_as_rvalue(rvalue_type, context) @@ -3674,8 +3676,9 @@ def check_simple_assignment( self.msg.deleted_as_lvalue(lvalue_type, context) elif lvalue_type: self.check_subtype( - rvalue_type, - lvalue_type, + # Preserve original aliases for error messages when possible. 
+ orig_rvalue, + orig_lvalue, context, msg, f"{rvalue_name} has type", @@ -5568,7 +5571,9 @@ def check_subtype( code = msg.code else: msg_text = msg + orig_subtype = subtype subtype = get_proper_type(subtype) + orig_supertype = supertype supertype = get_proper_type(supertype) if self.msg.try_report_long_tuple_assignment_error( subtype, supertype, context, msg_text, subtype_label, supertype_label, code=code @@ -5580,7 +5585,7 @@ def check_subtype( note_msg = "" notes: List[str] = [] if subtype_label is not None or supertype_label is not None: - subtype_str, supertype_str = format_type_distinctly(subtype, supertype) + subtype_str, supertype_str = format_type_distinctly(orig_subtype, orig_supertype) if subtype_label is not None: extra_info.append(subtype_label + " " + subtype_str) if supertype_label is not None: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2b947cdc8e32..b68aa5336a37 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -138,6 +138,7 @@ StarType, TupleType, Type, + TypeAliasType, TypedDictType, TypeOfAny, TypeType, @@ -147,6 +148,7 @@ flatten_nested_unions, get_proper_type, get_proper_types, + has_recursive_types, is_generic_instance, is_named_instance, is_optional, @@ -1534,6 +1536,17 @@ def infer_function_type_arguments( else: pass1_args.append(arg) + # This is a hack to better support inference for recursive types. + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. 
+ ctx = self.type_context[-1] + if ctx and has_recursive_types(ctx): + infer_unions = True + else: + infer_unions = False inferred_args = infer_function_type_arguments( callee_type, pass1_args, @@ -1541,6 +1554,7 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), + infer_unions=infer_unions, ) if 2 in arg_pass_nums: diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 4515a137ced2..5b7148be0c87 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -293,7 +293,8 @@ def expand_types_with_unpack( else: items.extend(unpacked_items) else: - items.append(proper_item.accept(self)) + # Must preserve original aliases when possible. + items.append(item.accept(self)) return items def visit_tuple_type(self, t: TupleType) -> Type: diff --git a/mypy/infer.py b/mypy/infer.py index d3ad0bc19f9b..1c00d2904702 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -34,6 +34,7 @@ def infer_function_type_arguments( formal_to_actual: List[List[int]], context: ArgumentInferContext, strict: bool = True, + infer_unions: bool = False, ) -> List[Optional[Type]]: """Infer the type arguments of a generic function. @@ -55,7 +56,7 @@ def infer_function_type_arguments( # Solve constraints. 
type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict) + return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions) def infer_type_arguments( diff --git a/mypy/main.py b/mypy/main.py index 85a1eb0765eb..de58cff404ea 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -977,6 +977,11 @@ def add_invertible_flag( dest="custom_typing_module", help="Use a custom typing module", ) + internals_group.add_argument( + "--enable-recursive-aliases", + action="store_true", + help="Experimental support for recursive type aliases", + ) internals_group.add_argument( "--custom-typeshed-dir", metavar="DIR", help="Use the custom typeshed in DIR" ) diff --git a/mypy/messages.py b/mypy/messages.py index 3a38f91253a4..128bb36c984a 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -87,6 +87,7 @@ ProperType, TupleType, Type, + TypeAliasType, TypedDictType, TypeOfAny, TypeType, @@ -2128,7 +2129,17 @@ def format_literal_value(typ: LiteralType) -> str: else: return typ.value_repr() - # TODO: show type alias names in errors. + if isinstance(typ, TypeAliasType) and typ.is_recursive: + # TODO: find balance here, str(typ) doesn't support custom verbosity, and may be + # too verbose for user messages, OTOH it nicely shows structure of recursive types. + if verbosity < 2: + type_str = typ.alias.name if typ.alias else "" + if typ.args: + type_str += f"[{format_list(typ.args)}]" + return type_str + return str(typ) + + # TODO: always mention type alias names in errors. typ = get_proper_type(typ) if isinstance(typ, Instance): diff --git a/mypy/options.py b/mypy/options.py index 860c296cfbb0..dd5afdc42a25 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -315,6 +315,8 @@ def __init__(self) -> None: # skip most errors after this many messages have been reported. # -1 means unlimited. 
self.many_errors_threshold = defaults.MANY_ERRORS_THRESHOLD + # Enable recursive type aliases (currently experimental) + self.enable_recursive_aliases = False # To avoid breaking plugin compatibility, keep providing new_semantic_analyzer @property diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 691af147d98f..33f2cdf7aa16 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -33,6 +33,10 @@ def is_same_type(left: Type, right: Type) -> bool: """Is 'left' the same type as 'right'?""" + if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType): + if left.is_recursive and right.is_recursive: + return left.alias == right.alias and left.args == right.args + left = get_proper_type(left) right = get_proper_type(right) diff --git a/mypy/semanal.py b/mypy/semanal.py index 928b084d981b..f38e1b236b45 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -450,6 +450,14 @@ def __init__( # current SCC or top-level function. self.deferral_debug_context: List[Tuple[str, int]] = [] + # This is needed to properly support recursive type aliases. The problem is that + # Foo[Bar] could mean three things depending on context: a target for type alias, + # a normal index expression (including enum index), or a type application. + # The latter is particularly problematic as it can falsely create incomplete + # refs while analysing rvalues of type aliases. To avoid this we first analyse + # rvalues while temporarily setting this to True. + self.basic_type_applications = False + # mypyc doesn't properly handle implementing an abstractproperty # with a regular attribute so we make them properties @property @@ -2286,7 +2294,14 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: return tag = self.track_incomplete_refs() - s.rvalue.accept(self) + + # Here we have a chicken and egg problem: at this stage we can't call + # can_be_type_alias(), because we have not enough information about rvalue. 
+ # But we can't use a full visit because it may emit extra incomplete refs (namely + # when analysing any type applications there) thus preventing the further analysis. + # To break the tie, we first analyse rvalue partially, if it can be a type alias. + with self.basic_type_applications_set(s): + s.rvalue.accept(self) if self.found_incomplete_ref(tag) or self.should_wait_rhs(s.rvalue): # Initializer couldn't be fully analyzed. Defer the current node and give up. # Make sure that if we skip the definition of some local names, they can't be @@ -2326,6 +2341,10 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: s.is_alias_def = False # OK, this is a regular assignment, perform the necessary analysis steps. + if self.can_possibly_be_indexed_alias(s): + # Do a full visit if this is not a type alias after all. This will give + # consistent error messages, and it is safe as semantic analyzer is idempotent. + s.rvalue.accept(self) s.is_final_def = self.unwrap_final(s) self.analyze_lvalues(s) self.check_final_implicit_def(s) @@ -2432,6 +2451,32 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool: return True return False + @contextmanager + def basic_type_applications_set(self, s: AssignmentStmt) -> Iterator[None]: + old = self.basic_type_applications + self.basic_type_applications = self.can_possibly_be_indexed_alias(s) + try: + yield + finally: + self.basic_type_applications = old + + def can_possibly_be_indexed_alias(self, s: AssignmentStmt) -> bool: + """Like can_be_type_alias(), but simpler and doesn't require analyzed rvalue. + + Instead, use lvalues/annotations structure to figure out whether this can + potentially be a type alias definition. 
+ """ + if len(s.lvalues) > 1: + return False + if not isinstance(s.lvalues[0], NameExpr): + return False + if s.unanalyzed_type is not None and not self.is_pep_613(s): + return False + if not isinstance(s.rvalue, IndexExpr): + return False + # Something that looks like Foo = Bar[Baz, ...] + return True + def is_type_ref(self, rv: Expression, bare: bool = False) -> bool: """Does this expression refer to a type? @@ -2908,6 +2953,13 @@ def analyze_alias( qualified_tvars = [] return typ, alias_tvars, depends_on, qualified_tvars + def is_pep_613(self, s: AssignmentStmt) -> bool: + if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType): + lookup = self.lookup_qualified(s.unanalyzed_type.name, s, suppress_errors=True) + if lookup and lookup.fullname in TYPE_ALIAS_NAMES: + return True + return False + def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: """Check if assignment creates a type alias and set it up as needed. @@ -2922,11 +2974,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # First rule: Only simple assignments like Alias = ... create aliases. return False - pep_613 = False - if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType): - lookup = self.lookup_qualified(s.unanalyzed_type.name, s, suppress_errors=True) - if lookup and lookup.fullname in TYPE_ALIAS_NAMES: - pep_613 = True + pep_613 = self.is_pep_613(s) if not pep_613 and s.unanalyzed_type is not None: # Second rule: Explicit type (cls: Type[A] = A) always creates variable, not alias. # unless using PEP 613 `cls: TypeAlias = A` @@ -2990,9 +3038,16 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: ) if not res: return False - # TODO: Maybe we only need to reject top-level placeholders, similar - # to base classes. 
- if self.found_incomplete_ref(tag) or has_placeholder(res): + if self.options.enable_recursive_aliases: + # Only marking incomplete for top-level placeholders makes recursive aliases like + # `A = Sequence[str | A]` valid here, similar to how we treat base classes in class + # definitions, allowing `class str(Sequence[str]): ...` + incomplete_target = isinstance(res, ProperType) and isinstance( + res, PlaceholderType + ) + else: + incomplete_target = has_placeholder(res) + if self.found_incomplete_ref(tag) or incomplete_target: # Since we have got here, we know this must be a type alias (incomplete refs # may appear in nested positions), therefore use becomes_typeinfo=True. self.mark_incomplete(lvalue.name, rvalue, becomes_typeinfo=True) @@ -4499,6 +4554,9 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] self.analyze_type_expr(index) if self.found_incomplete_ref(tag): return None + if self.basic_type_applications: + # Postpone the rest until we are sure this is not a r.h.s. of a type alias. + return None types: List[Type] = [] if isinstance(index, TupleExpr): items = index.items diff --git a/mypy/solve.py b/mypy/solve.py index 2c3a5b5e3300..4a284a3b8d0b 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -7,11 +7,22 @@ from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype -from mypy.types import AnyType, Type, TypeOfAny, TypeVarId, UninhabitedType, get_proper_type +from mypy.types import ( + AnyType, + Type, + TypeOfAny, + TypeVarId, + UninhabitedType, + UnionType, + get_proper_type, +) def solve_constraints( - vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True + vars: List[TypeVarId], + constraints: List[Constraint], + strict: bool = True, + infer_unions: bool = False, ) -> List[Optional[Type]]: """Solve type constraints. 
@@ -43,7 +54,12 @@ def solve_constraints( if bottom is None: bottom = c.target else: - bottom = join_types(bottom, c.target) + if infer_unions: + # This deviates from the general mypy semantics because + # recursive types are union-heavy in 95% of cases. + bottom = UnionType.make_union([bottom, c.target]) + else: + bottom = join_types(bottom, c.target) else: if top is None: top = c.target diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c5d2cc5e98c1..7ef702e8493d 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -105,13 +105,15 @@ def is_subtype( if TypeState.is_assumed_subtype(left, right): return True if ( + # TODO: recursive instances like `class str(Sequence[str])` can also cause + # issues, so we also need to include them in the assumptions stack isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType) and left.is_recursive and right.is_recursive ): # This case requires special care because it may cause infinite recursion. - # Our view on recursive types is known under a fancy name of equirecursive mu-types. + # Our view on recursive types is known under a fancy name of iso-recursive mu-types. 
# Roughly this means that a recursive type is defined as an alias where right hand side # can refer to the type as a whole, for example: # A = Union[int, Tuple[A, ...]] diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d6615d4f4c9e..eb0ba7d03a43 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -362,7 +362,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) if ( isinstance(res, Instance) # type: ignore[misc] and len(res.args) != len(res.type.type_vars) - and not self.defining_alias + and (not self.defining_alias or self.nesting_level) ): fix_instance( res, diff --git a/mypy/typeops.py b/mypy/typeops.py index 7fc012fd3c78..354e51aff7ff 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -55,6 +55,7 @@ UninhabitedType, UnionType, UnpackType, + flatten_nested_unions, get_proper_type, get_proper_types, ) @@ -436,17 +437,8 @@ def make_simplified_union( back into a sum type. Set it to False when called by try_expanding_sum_type_ to_union(). """ - items = get_proper_types(items) - # Step 1: expand all nested unions - while any(isinstance(typ, UnionType) for typ in items): - all_items: List[ProperType] = [] - for typ in items: - if isinstance(typ, UnionType): - all_items.extend(get_proper_types(typ.items)) - else: - all_items.append(typ) - items = all_items + items = cast(List[ProperType], flatten_nested_unions(items, handle_type_alias_type=True)) # Step 2: remove redundant unions simplified_set = _remove_redundant_union_items(items, keep_erased) diff --git a/mypy/types.py b/mypy/types.py index ad39edee4112..43276bf6d628 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3035,7 +3035,29 @@ def list_str(self, a: Iterable[Type]) -> str: return ", ".join(res) -class UnrollAliasVisitor(TypeTranslator): +class TrivialSyntheticTypeTranslator(TypeTranslator, SyntheticTypeVisitor[Type]): + """A base class for type translators that need to be run during semantic analysis.""" + + def visit_placeholder_type(self, t: PlaceholderType) -> 
Type: + return t + + def visit_callable_argument(self, t: CallableArgument) -> Type: + return t + + def visit_ellipsis_type(self, t: EllipsisType) -> Type: + return t + + def visit_raw_expression_type(self, t: RawExpressionType) -> Type: + return t + + def visit_star_type(self, t: StarType) -> Type: + return t + + def visit_type_list(self, t: TypeList) -> Type: + return t + + +class UnrollAliasVisitor(TrivialSyntheticTypeTranslator): def __init__(self, initial_aliases: Set[TypeAliasType]) -> None: self.recursed = False self.initial_aliases = initial_aliases @@ -3074,7 +3096,7 @@ def is_named_instance(t: Type, fullnames: Union[str, Tuple[str, ...]]) -> bool: return isinstance(t, Instance) and t.type.fullname in fullnames -class InstantiateAliasVisitor(TypeTranslator): +class InstantiateAliasVisitor(TrivialSyntheticTypeTranslator): def __init__(self, vars: List[str], subs: List[Type]) -> None: self.replacements = {v: s for (v, s) in zip(vars, subs)} @@ -3122,6 +3144,19 @@ def has_type_vars(typ: Type) -> bool: return typ.accept(HasTypeVars()) +class HasRecursiveType(TypeQuery[bool]): + def __init__(self) -> None: + super().__init__(any) + + def visit_type_alias_type(self, t: TypeAliasType) -> bool: + return t.is_recursive + + +def has_recursive_types(typ: Type) -> bool: + """Check if a type contains any recursive aliases (recursively).""" + return typ.accept(HasRecursiveType()) + + def flatten_nested_unions( types: Iterable[Type], handle_type_alias_type: bool = False ) -> List[Type]: @@ -3130,16 +3165,16 @@ def flatten_nested_unions( # if passed a "pathological" alias like A = Union[int, A] or similar. # TODO: ban such aliases in semantic analyzer. flat_items: List[Type] = [] - if handle_type_alias_type: - types = get_proper_types(types) # TODO: avoid duplicate types in unions (e.g. 
using hash) - for tp in types: + for t in types: + tp = get_proper_type(t) if handle_type_alias_type else t if isinstance(tp, ProperType) and isinstance(tp, UnionType): flat_items.extend( flatten_nested_unions(tp.items, handle_type_alias_type=handle_type_alias_type) ) else: - flat_items.append(tp) + # Must preserve original aliases when possible. + flat_items.append(t) return flat_items diff --git a/test-data/unit/check-newsemanal.test b/test-data/unit/check-newsemanal.test index 163805ab4bcb..bf612f95b3a2 100644 --- a/test-data/unit/check-newsemanal.test +++ b/test-data/unit/check-newsemanal.test @@ -3229,3 +3229,8 @@ class b: T = Union[Any] [builtins fixtures/tuple.pyi] + +[case testSelfReferentialSubscriptExpression] +x = x[1] # E: Cannot resolve name "x" (possible cyclic definition) +y = 1[y] # E: Value of type "int" is not indexable \ + # E: Cannot determine type of "y" diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test new file mode 100644 index 000000000000..b636cc7d8bba --- /dev/null +++ b/test-data/unit/check-recursive-types.test @@ -0,0 +1,59 @@ +[case testRecursiveAliasBasic] +# flags: --enable-recursive-aliases +from typing import Dict, List, Union, TypeVar, Sequence + +JSON = Union[str, List[JSON], Dict[str, JSON]] + +x: JSON = ["foo", {"bar": "baz"}] + +reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[...], builtins.dict[builtins.str, ...]]" +if isinstance(x, list): + x = x[0] + +class Bad: ... +x = ["foo", {"bar": [Bad()]}] # E: List item 0 has incompatible type "Bad"; expected "Union[str, List[JSON], Dict[str, JSON]]" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasBasicGenericSubtype] +# flags: --enable-recursive-aliases +from typing import Union, TypeVar, Sequence, List + +T = TypeVar("T") + +Nested = Sequence[Union[T, Nested[T]]] + +class Bad: ... 
+x: Nested[int] +y: Nested[Bad] +x = y # E: Incompatible types in assignment (expression has type "Nested[Bad]", variable has type "Nested[int]") + +NestedOther = Sequence[Union[T, Nested[T]]] + +xx: Nested[int] +yy: NestedOther[bool] +xx = yy # OK +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasBasicGenericInference] +# flags: --enable-recursive-aliases +from typing import Union, TypeVar, Sequence, List + +T = TypeVar("T") + +Nested = Sequence[Union[T, Nested[T]]] + +def flatten(arg: Nested[T]) -> List[T]: + res: List[T] = [] + for item in arg: + if isinstance(item, Sequence): + res.extend(flatten(item)) + else: + res.append(item) + return res + +reveal_type(flatten([1, [2, [3]]])) # N: Revealed type is "builtins.list[builtins.int]" + +class Bad: ... +x: Nested[int] = [1, [2, [3]]] +x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" +[builtins fixtures/isinstancelist.pyi] diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 3865b6999ab0..0ee5258ff74b 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -41,6 +41,8 @@ class list(Sequence[T]): def __getitem__(self, x: int) -> T: pass def __add__(self, x: List[T]) -> T: pass def __contains__(self, item: object) -> bool: pass + def append(self, x: T) -> None: pass + def extend(self, x: Iterable[T]) -> None: pass class dict(Mapping[KT, VT]): @overload From 43616fab6e4283de7cc5064f8a1ec9e36223c2c3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 31 Jul 2022 19:47:34 +0100 Subject: [PATCH 2/7] Fixes --- mypy/checker.py | 2 +- mypy/checkexpr.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f0052df1e648..4ef5c50502e0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3678,7 +3678,7 @@ def check_simple_assignment( self.check_subtype( # Preserve original aliases for error messages 
when possible. orig_rvalue, - orig_lvalue, + orig_lvalue or lvalue_type, context, msg, f"{rvalue_name} has type", diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b68aa5336a37..043b8a81faa3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -138,7 +138,6 @@ StarType, TupleType, Type, - TypeAliasType, TypedDictType, TypeOfAny, TypeType, From 056fe57c0cc9522b89385db5cf6c42ea3088795e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 31 Jul 2022 21:18:08 +0100 Subject: [PATCH 3/7] Fix wrong cast --- mypy/typeops.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 354e51aff7ff..0db15cdb7bc3 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -438,19 +438,22 @@ def make_simplified_union( to_union(). """ # Step 1: expand all nested unions - items = cast(List[ProperType], flatten_nested_unions(items, handle_type_alias_type=True)) + items = flatten_nested_unions(items, handle_type_alias_type=True) # Step 2: remove redundant unions simplified_set = _remove_redundant_union_items(items, keep_erased) # Step 3: If more than one literal exists in the union, try to simplify - if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1: + if ( + contract_literals + and sum(isinstance(get_proper_type(item), LiteralType) for item in simplified_set) > 1 + ): simplified_set = try_contracting_literals_in_union(simplified_set) - return UnionType.make_union(simplified_set, line, column) + return get_proper_type(UnionType.make_union(simplified_set, line, column)) -def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]: +def _remove_redundant_union_items(items: List[Type], keep_erased: bool) -> List[Type]: from mypy.subtypes import is_proper_subtype removed: Set[int] = set() @@ -461,10 +464,11 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> # different enum types as 
try_expanding_sum_type_to_union works recursively and will # trigger intermediate simplifications that would render the fast path useless for i, item in enumerate(items): + proper_item = get_proper_type(item) if i in removed: continue # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) - k = simple_literal_value_key(item) + k = simple_literal_value_key(proper_item) if k is not None: if k in seen: removed.add(i) @@ -485,6 +489,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> # Keep track of the truthiness info for deleted subtypes which can be relevant cbt = cbf = False for j, tj in enumerate(items): + proper_tj = get_proper_type(tj) if ( i == j # avoid further checks if this item was already marked redundant. @@ -495,11 +500,11 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> # However, if the current item is not a literal, it might plausibly be a # supertype of other literals in the union, so we must check them again. # This is an important optimization as is_proper_subtype is pretty expensive. - or (k is not None and is_simple_literal(tj)) + or (k is not None and is_simple_literal(proper_tj)) ): continue - # actual redundancy checks - if is_redundant_literal_instance(item, tj) and is_proper_subtype( # XXX? + # actual redundancy checks (XXX?) + if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype( tj, item, keep_erased_types=keep_erased ): # We found a redundant item in the union. 
From 6771dd46039ed67aa454aa8434acd4827c97d02c Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 31 Jul 2022 21:24:05 +0100 Subject: [PATCH 4/7] Fix self-check --- mypy/typeops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 0db15cdb7bc3..f7b14c710cc2 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -441,7 +441,7 @@ def make_simplified_union( items = flatten_nested_unions(items, handle_type_alias_type=True) # Step 2: remove redundant unions - simplified_set = _remove_redundant_union_items(items, keep_erased) + simplified_set: Sequence[Type] = _remove_redundant_union_items(items, keep_erased) # Step 3: If more than one literal exists in the union, try to simplify if ( From 2dd986ec4184ada3772d447da065a93c1c9d9997 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 1 Aug 2022 01:52:36 +0100 Subject: [PATCH 5/7] Support new style aliases --- mypy/semanal.py | 10 +++++----- mypy/typeanal.py | 2 +- test-data/unit/check-recursive-types.test | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index f38e1b236b45..207548de0734 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2309,6 +2309,10 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for expr in names_modified_by_assignment(s): self.mark_incomplete(expr.name, expr) return + if self.can_possibly_be_indexed_alias(s): + # Now re-visit those rvalues that were we skipped type applications above. + # This should be safe as generally semantic analyzer is idempotent. + s.rvalue.accept(self) # The r.h.s. is now ready to be classified, first check if it is a special form: special_form = False @@ -2341,10 +2345,6 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: s.is_alias_def = False # OK, this is a regular assignment, perform the necessary analysis steps. 
- if self.can_possibly_be_indexed_alias(s): - # Do a full visit if this is not a type alias after all. This will give - # consistent error messages, and it is safe as semantic analyzer is idempotent. - s.rvalue.accept(self) s.is_final_def = self.unwrap_final(s) self.analyze_lvalues(s) self.check_final_implicit_def(s) @@ -2472,7 +2472,7 @@ def can_possibly_be_indexed_alias(self, s: AssignmentStmt) -> bool: return False if s.unanalyzed_type is not None and not self.is_pep_613(s): return False - if not isinstance(s.rvalue, IndexExpr): + if not isinstance(s.rvalue, (IndexExpr, OpExpr)): return False # Something that looks like Foo = Bar[Baz, ...] return True diff --git a/mypy/typeanal.py b/mypy/typeanal.py index eb0ba7d03a43..d6615d4f4c9e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -362,7 +362,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) if ( isinstance(res, Instance) # type: ignore[misc] and len(res.args) != len(res.type.type_vars) - and (not self.defining_alias or self.nesting_level) + and not self.defining_alias ): fix_instance( res, diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index b636cc7d8bba..fe8b651521a8 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -57,3 +57,18 @@ class Bad: ... 
x: Nested[int] = [1, [2, [3]]] x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" [builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasNewStyleSupported] +# flags: --enable-recursive-aliases +from test import A + +x: A +if isinstance(x, list): + reveal_type(x[0]) # N: Revealed type is "Union[builtins.int, builtins.list[Union[builtins.int, builtins.list[...]]]]" +else: + reveal_type(x) # N: Revealed type is "builtins.int" + +[file test.pyi] +A = int | list[A] + +[builtins fixtures/isinstancelist.pyi] From e7914bf60197e0b0c6f523c6c6671bdaa4650c63 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 1 Aug 2022 15:04:22 +0100 Subject: [PATCH 6/7] Double-test some existing tests --- mypy/semanal.py | 2 +- test-data/unit/check-recursive-types.test | 120 ++++++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 207548de0734..de271dd80535 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -4555,7 +4555,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] if self.found_incomplete_ref(tag): return None if self.basic_type_applications: - # Postpone the rest until we are sure this is not a r.h.s. of a type alias. + # Postpone the rest until we have more information (for r.h.s. 
of an assignment) return None types: List[Type] = [] if isinstance(index, TupleExpr): diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index fe8b651521a8..3c7db5827ed1 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -1,3 +1,5 @@ +-- Tests checking that basic functionality works + [case testRecursiveAliasBasic] # flags: --enable-recursive-aliases from typing import Dict, List, Union, TypeVar, Sequence @@ -70,5 +72,123 @@ else: [file test.pyi] A = int | list[A] +[builtins fixtures/isinstancelist.pyi] + +-- Tests duplicating some existing tests with recursive aliases enabled + +[case testRecursiveAliasesMutual] +# flags: --enable-recursive-aliases +from typing import Type, Callable, Union + +A = Union[B, int] +B = Callable[[C], int] +C = Type[A] +x: A +reveal_type(x) # N: Revealed type is "Union[def (Union[Type[def (...) -> builtins.int], Type[builtins.int]]) -> builtins.int, builtins.int]" + +[case testRecursiveAliasesProhibited-skip] +# flags: --enable-recursive-aliases +from typing import Type, Callable, Union + +A = Union[B, int] +B = Union[A, int] +C = Type[C] + +[case testRecursiveAliasImported] +# flags: --enable-recursive-aliases +import lib +x: lib.A +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list[...]]" + +[file lib.pyi] +from typing import List +from other import B +A = List[B] + +[file other.pyi] +from typing import List +from lib import A +B = List[A] +[builtins fixtures/list.pyi] + +[case testRecursiveAliasViaBaseClass] +# flags: --enable-recursive-aliases +from typing import List +x: B +B = List[C] +class C(B): pass + +reveal_type(x) # N: Revealed type is "builtins.list[__main__.C]" +reveal_type(x[0][0]) # N: Revealed type is "__main__.C" +[builtins fixtures/list.pyi] + +[case testRecursiveAliasViaBaseClass2] +# flags: --enable-recursive-aliases +from typing import NewType, List + +x: D +reveal_type(x[0][0]) # N: Revealed type 
is "__main__.C" + +D = List[C] +C = NewType('C', B) + +class B(D): + pass +[builtins fixtures/list.pyi] + +[case testRecursiveAliasViaBaseClass3] +# flags: --enable-recursive-aliases +from typing import List, Generic, TypeVar, NamedTuple +T = TypeVar('T') + +class C(A, B): + pass +class G(Generic[T]): pass +A = G[C] +class B(NamedTuple): + x: int + +y: C +reveal_type(y.x) # N: Revealed type is "builtins.int" +reveal_type(y[0]) # N: Revealed type is "builtins.int" +x: A +reveal_type(x) # N: Revealed type is "__main__.G[Tuple[builtins.int, fallback=__main__.C]]" +[builtins fixtures/list.pyi] + +[case testRecursiveAliasViaBaseClassImported] +# flags: --enable-recursive-aliases +import a +[file a.py] +from typing import List +from b import D + +def f(x: B) -> List[B]: ... +B = List[C] +class C(B): pass + +[file b.py] +from a import f +class D: ... +reveal_type(f) # N: Revealed type is "def (x: builtins.list[a.C]) -> builtins.list[builtins.list[a.C]]" +[builtins fixtures/list.pyi] + +[case testRecursiveAliasViaNamedTuple] +# flags: --enable-recursive-aliases +from typing import List, NamedTuple, Union + +Exp = Union['A', 'B'] +class A(NamedTuple('A', [('attr', List[Exp])])): pass +class B(NamedTuple('B', [('val', object)])): pass + +def my_eval(exp: Exp) -> int: + reveal_type(exp) # N: Revealed type is "Union[Tuple[builtins.list[...], fallback=__main__.A], Tuple[builtins.object, fallback=__main__.B]]" + if isinstance(exp, A): + my_eval(exp[0][0]) + return my_eval(exp.attr[0]) + if isinstance(exp, B): + return exp.val # E: Incompatible return value type (got "object", expected "int") + return 0 + +my_eval(A([B(1), B(2)])) [builtins fixtures/isinstancelist.pyi] From 3664c56c8da0321408780f5351cde220ddf6d4dd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 3 Aug 2022 10:56:17 +0100 Subject: [PATCH 7/7] Add comments and re-order methods more logically --- mypy/semanal.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git 
a/mypy/semanal.py b/mypy/semanal.py index de271dd80535..a592b4fdde9e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2309,7 +2309,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for expr in names_modified_by_assignment(s): self.mark_incomplete(expr.name, expr) return - if self.can_possibly_be_indexed_alias(s): + if self.can_possibly_be_index_alias(s): # Now re-visit those rvalues that were we skipped type applications above. # This should be safe as generally semantic analyzer is idempotent. s.rvalue.accept(self) @@ -2451,20 +2451,13 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool: return True return False - @contextmanager - def basic_type_applications_set(self, s: AssignmentStmt) -> Iterator[None]: - old = self.basic_type_applications - self.basic_type_applications = self.can_possibly_be_indexed_alias(s) - try: - yield - finally: - self.basic_type_applications = old - - def can_possibly_be_indexed_alias(self, s: AssignmentStmt) -> bool: + def can_possibly_be_index_alias(self, s: AssignmentStmt) -> bool: """Like can_be_type_alias(), but simpler and doesn't require analyzed rvalue. Instead, use lvalues/annotations structure to figure out whether this can - potentially be a type alias definition. + potentially be a type alias definition. Another difference from the above function + is that we are only interested in IndexExpr and OpExpr rvalues, since only those + can be potentially recursive (things like `A = A` are never valid). """ if len(s.lvalues) > 1: return False @@ -2477,6 +2470,17 @@ def can_possibly_be_indexed_alias(self, s: AssignmentStmt) -> bool: # Something that looks like Foo = Bar[Baz, ...] return True + @contextmanager + def basic_type_applications_set(self, s: AssignmentStmt) -> Iterator[None]: + old = self.basic_type_applications + # As an optimization, only use the double visit logic if this + # can possibly be a recursive type alias.
+ self.basic_type_applications = self.can_possibly_be_index_alias(s) + try: + yield + finally: + self.basic_type_applications = old + def is_type_ref(self, rv: Expression, bare: bool = False) -> bool: """Does this expression refer to a type?