From 5543b5a05b372784c01d44b78dd92d2336877414 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 5 Feb 2017 17:08:57 -0800 Subject: [PATCH 1/2] Allow E[] where E is an Enum type. Fixes #1381. --- mypy/checkexpr.py | 13 +++++++++++++ mypy/semanal.py | 6 +++++- test-data/unit/pythoneval-enum.test | 19 ++++++++++++++++++- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 79454d029ee2..a6ccfb311d19 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1439,6 +1439,9 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type: return AnyType() elif isinstance(left_type, TypedDictType): return self.visit_typeddict_index_expr(left_type, e.index) + elif (isinstance(left_type, CallableType) + and left_type.is_type_obj() and left_type.type_object()): + return self.visit_enum_index_expr(left_type.type_object(), e.index, e) else: result, method_type = self.check_op('__getitem__', left_type, e.index, e) e.method_type = method_type @@ -1497,6 +1500,16 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) return AnyType() return item_type + def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression, + context: Context) -> Type: + string_type = self.named_type('builtins.str') # type: Type + if self.chk.options.python_version[0] < 3: + string_type = UnionType.make_union([string_type, + self.named_type('builtins.unicode')]) + self.chk.check_subtype(self.accept(index), string_type, context, + "Enum index should be a string", "actual index type") + return Instance(enum_type, []) + def visit_cast_expr(self, expr: CastExpr) -> Type: """Type check a cast expression.""" source_type = self.accept(expr.expr, context=AnyType()) diff --git a/mypy/semanal.py b/mypy/semanal.py index 71a8323be292..8c5146e8a056 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2455,7 +2455,11 @@ def visit_unary_expr(self, expr: UnaryExpr) -> None: def visit_index_expr(self, expr: IndexExpr) -> None: expr.base.accept(self) - if isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS: + if (isinstance(expr.base, RefExpr) + and isinstance(expr.base.node, TypeInfo) + and expr.base.node.is_enum): + expr.index.accept(self) + elif isinstance(expr.base, RefExpr) and expr.base.kind == TYPE_ALIAS: # Special form -- subscripting a generic type alias. # Perform the type substitution and create a new alias. res = analyze_type_alias(expr, diff --git a/test-data/unit/pythoneval-enum.test b/test-data/unit/pythoneval-enum.test index 3ae2df55f120..b0e7a7dd1986 100644 --- a/test-data/unit/pythoneval-enum.test +++ b/test-data/unit/pythoneval-enum.test @@ -119,7 +119,6 @@ def takes_some_ext_int_enum(s: SomeExtIntEnum): pass takes_some_ext_int_enum(SomeExtIntEnum.x) - [case testNamedTupleEnum] from typing import NamedTuple from enum import Enum @@ -132,3 +131,21 @@ class E(N, Enum): def f(x: E) -> None: pass f(E.X) + +[case testEnumCall] +from enum import IntEnum +class E(IntEnum): + a = 1 +x = None # type: int +reveal_type(E(x)) +[out] +_program.py:5: error: Revealed type is '_testEnumCall.E' + +[case testEnumIndex] +from enum import IntEnum +class E(IntEnum): + a = 1 +s = None # type: str +reveal_type(E[s]) +[out] +_program.py:5: error: Revealed type is '_testEnumIndex.E' From d82937691e3f6c6dd5efa8218bd07dd830905e9b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 6 Feb 2017 09:40:04 -0800 Subject: [PATCH 2/2] Respond to review by @ilevkivskyi --- mypy/checkexpr.py | 2 +- test-data/unit/pythoneval-enum.test | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a6ccfb311d19..432d77fefc91 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1440,7 +1440,7 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type: elif isinstance(left_type, TypedDictType): return self.visit_typeddict_index_expr(left_type, e.index) elif (isinstance(left_type, CallableType) - and left_type.is_type_obj() and left_type.type_object()): + and left_type.is_type_obj() and left_type.type_object().is_enum): return self.visit_enum_index_expr(left_type.type_object(), e.index, e) else: result, method_type = self.check_op('__getitem__', left_type, e.index, e) diff --git a/test-data/unit/pythoneval-enum.test b/test-data/unit/pythoneval-enum.test index b0e7a7dd1986..dfbd8f73cef0 100644 --- a/test-data/unit/pythoneval-enum.test +++ b/test-data/unit/pythoneval-enum.test @@ -149,3 +149,11 @@ s = None # type: str reveal_type(E[s]) [out] _program.py:5: error: Revealed type is '_testEnumIndex.E' + +[case testEnumIndexError] +from enum import IntEnum +class E(IntEnum): + a = 1 +E[1] +[out] +_program.py:4: error: Enum index should be a string (actual index type "int")