diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 79454d029ee2..432d77fefc91 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1439,6 +1439,9 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type: return AnyType() elif isinstance(left_type, TypedDictType): return self.visit_typeddict_index_expr(left_type, e.index) + elif (isinstance(left_type, CallableType) + and left_type.is_type_obj() and left_type.type_object().is_enum): + return self.visit_enum_index_expr(left_type.type_object(), e.index, e) else: result, method_type = self.check_op('__getitem__', left_type, e.index, e) e.method_type = method_type @@ -1497,6 +1500,16 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) return AnyType() return item_type + def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression, + context: Context) -> Type: + string_type = self.named_type('builtins.str') # type: Type + if self.chk.options.python_version[0] < 3: + string_type = UnionType.make_union([string_type, + self.named_type('builtins.unicode')]) + self.chk.check_subtype(self.accept(index), string_type, context, + "Enum index should be a string", "actual index type") + return Instance(enum_type, []) + def visit_cast_expr(self, expr: CastExpr) -> Type: """Type check a cast expression.""" source_type = self.accept(expr.expr, context=AnyType()) diff --git a/mypy/semanal.py b/mypy/semanal.py index 71a8323be292..8c5146e8a056 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2455,7 +2455,11 @@ def visit_unary_expr(self, expr: UnaryExpr) -> None: def visit_index_expr(self, expr: IndexExpr) -> None: expr.base.accept(self) - if isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS: + if (isinstance(expr.base, RefExpr) + and isinstance(expr.base.node, TypeInfo) + and expr.base.node.is_enum): + expr.index.accept(self) + elif isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS: # Special form -- subscripting a generic type alias. # Perform the type substitution and create a new alias. res = analyze_type_alias(expr, diff --git a/test-data/unit/pythoneval-enum.test b/test-data/unit/pythoneval-enum.test index 3ae2df55f120..dfbd8f73cef0 100644 --- a/test-data/unit/pythoneval-enum.test +++ b/test-data/unit/pythoneval-enum.test @@ -119,7 +119,6 @@ def takes_some_ext_int_enum(s: SomeExtIntEnum): pass takes_some_ext_int_enum(SomeExtIntEnum.x) - [case testNamedTupleEnum] from typing import NamedTuple from enum import Enum @@ -132,3 +131,29 @@ class E(N, Enum): def f(x: E) -> None: pass f(E.X) + +[case testEnumCall] +from enum import IntEnum +class E(IntEnum): + a = 1 +x = None # type: int +reveal_type(E(x)) +[out] +_program.py:5: error: Revealed type is '_testEnumCall.E' + +[case testEnumIndex] +from enum import IntEnum +class E(IntEnum): + a = 1 +s = None # type: str +reveal_type(E[s]) +[out] +_program.py:5: error: Revealed type is '_testEnumIndex.E' + +[case testEnumIndexError] +from enum import IntEnum +class E(IntEnum): + a = 1 +E[1] +[out] +_program.py:4: error: Enum index should be a string (actual index type "int")