Skip to content

Commit 25b6e79

Browse files
committed
Fix mypyc crash with enum type aliases
mypyc was crashing because it couldn't find the type in the type map. This PR adds a generic AnyType to the type map if an expression isn't in the map already. Tried actually changing mypy to accept these type alias expressions, but ran into problems with nested type aliases where the inner one doesn't have the "analyzed" value and ending up with wrong results. fixes mypyc/mypyc#1064
1 parent d87f0b2 commit 25b6e79

File tree

4 files changed

+44
-1
lines changed

4 files changed

+44
-1
lines changed

mypyc/irbuild/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def build_ir(
7373

7474
for module in modules:
7575
# First pass to determine free symbols.
76-
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove)
76+
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove, types)
7777
module.accept(pbv)
7878

7979
# Construct and configure builder objects (cyclic runtime dependency).

mypyc/irbuild/missingtypevisitor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from mypy.nodes import Expression, Node
4+
from mypy.traverser import ExtendedTraverserVisitor
5+
from mypy.types import Type, AnyType, TypeOfAny
6+
7+
8+
class MissingTypesVisitor(ExtendedTraverserVisitor):
9+
"""AST visitor that can be used to add any missing types as a generic AnyType."""
10+
11+
def __init__(self, types: dict[Expression, Type]) -> None:
12+
super().__init__()
13+
self.types: dict[Expression, Type] = types
14+
15+
def visit(self, o: Node) -> bool:
16+
if isinstance(o, Expression) and o not in self.types:
17+
self.types[o] = AnyType(TypeOfAny.special_form)
18+
19+
# If returns True, will continue to nested nodes.
20+
return True

mypyc/irbuild/prebuildvisitor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from __future__ import annotations
22

33
from mypy.nodes import (
4+
AssignmentStmt,
45
Block,
56
Decorator,
67
Expression,
78
FuncDef,
89
FuncItem,
910
Import,
11+
IndexExpr,
1012
LambdaExpr,
1113
MemberExpr,
1214
MypyFile,
@@ -16,7 +18,9 @@
1618
Var,
1719
)
1820
from mypy.traverser import ExtendedTraverserVisitor
21+
from mypy.types import Type
1922
from mypyc.errors import Errors
23+
from mypyc.irbuild.missingtypevisitor import MissingTypesVisitor
2024

2125

2226
class PreBuildVisitor(ExtendedTraverserVisitor):
@@ -39,6 +43,7 @@ def __init__(
3943
errors: Errors,
4044
current_file: MypyFile,
4145
decorators_to_remove: dict[FuncDef, list[int]],
46+
types: dict[Expression, Type],
4247
) -> None:
4348
super().__init__()
4449
# Dict from a function to symbols defined directly in the
@@ -82,11 +87,20 @@ def __init__(
8287

8388
self.current_file: MypyFile = current_file
8489

90+
self.missing_types_visitor = MissingTypesVisitor(types)
91+
8592
def visit(self, o: Node) -> bool:
8693
if not isinstance(o, Import):
8794
self._current_import_group = None
8895
return True
8996

97+
def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
98+
# These are cases where mypy may not have types for certain expressions,
99+
# but mypyc needs some form type to exist.
100+
if isinstance(stmt.rvalue, IndexExpr) and stmt.rvalue.analyzed:
101+
stmt.rvalue.accept(self.missing_types_visitor)
102+
return super().visit_assignment_stmt(stmt)
103+
90104
def visit_block(self, block: Block) -> None:
91105
self._current_import_group = None
92106
super().visit_block(block)

mypyc/test-data/irbuild-classes.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,3 +1335,12 @@ def outer():
13351335
if True:
13361336
class OtherInner: # E: Nested class definitions not supported
13371337
pass
1338+
1339+
[case testEnumClassAlias]
1340+
from enum import Enum
1341+
from typing import Literal
1342+
1343+
class SomeEnum(Enum):
1344+
AVALUE = "a"
1345+
1346+
ALIAS = Literal[SomeEnum.AVALUE]

0 commit comments

Comments
 (0)