Skip to content

Support functional API for Enum. #2805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt,
WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt,
RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr,
UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode,
Context, Decorator, PrintStmt, LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt,
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, ImportFrom, ImportAll, ImportBase,
ARG_POS, CONTRAVARIANT, COVARIANT, ExecStmt, GlobalDecl, Import, NonlocalDecl,
MDEF, Node
)
BytesExpr, UnicodeExpr, FloatExpr, OpExpr, UnaryExpr, CastExpr, RevealTypeExpr, SuperExpr,
TypeApplication, DictExpr, SliceExpr, LambdaExpr, TempNode, SymbolTableNode,
Context, ListComprehension, ConditionalExpr, GeneratorExpr,
Decorator, SetExpr, TypeVarExpr, NewTypeExpr, PrintStmt,
LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr,
YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
RefExpr, YieldExpr, BackquoteExpr, Import, ImportFrom, ImportAll, ImportBase,
AwaitExpr, PromoteExpr, Node, EnumCallExpr,
ARG_POS, MDEF,
CONTRAVARIANT, COVARIANT)
from mypy import nodes
from mypy.types import (
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
Expand All @@ -45,7 +50,7 @@
from mypy.semanal import set_callable_name, refers_to_fullname
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.visitor import StatementVisitor
from mypy.visitor import NodeVisitor
from mypy.join import join_types
from mypy.treetransform import TransformVisitor
from mypy.binder import ConditionalTypeBinder, get_declaration
Expand All @@ -70,7 +75,7 @@
])


class TypeChecker(StatementVisitor[None]):
class TypeChecker(NodeVisitor[None]):
"""Mypy type checker.

Type check mypy source files that have been semantically analyzed.
Expand Down Expand Up @@ -2259,21 +2264,7 @@ def visit_break_stmt(self, s: BreakStmt) -> None:

def visit_continue_stmt(self, s: ContinueStmt) -> None:
self.binder.handle_continue()

def visit_exec_stmt(self, s: ExecStmt) -> None:
pass

def visit_global_decl(self, s: GlobalDecl) -> None:
pass

def visit_nonlocal_decl(self, s: NonlocalDecl) -> None:
pass

def visit_var(self, s: Var) -> None:
pass

def visit_pass_stmt(self, s: PassStmt) -> None:
pass
return None

#
# Helpers
Expand Down
25 changes: 24 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
TypeAliasExpr, BackquoteExpr, EnumCallExpr,
ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF,
UNBOUND_TVAR, BOUND_TVAR, LITERAL_TYPE
)
from mypy import nodes
Expand Down Expand Up @@ -349,6 +350,12 @@ def check_call(self, callee: Type, args: List[Expression],
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
# An Enum() call that failed SemanticAnalyzer.check_enum_call().
return callee.ret_type, callee

if (callee.is_type_obj() and callee.type_object().is_abstract
# Exceptions for Type[...] and classmethod first argument
and not callee.from_type_type and not callee.is_classmethod_class):
Expand Down Expand Up @@ -2199,6 +2206,22 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type:
# TODO: Perhaps return a type object type?
return AnyType()

def visit_enum_call_expr(self, e: EnumCallExpr) -> Type:
for name, value in zip(e.items, e.values):
if value is not None:
typ = self.accept(value)
if not isinstance(typ, AnyType):
var = e.info.names[name].node
if isinstance(var, Var):
# Inline TypeCheker.set_inferred_type(),
# without the lvalue. (This doesn't really do
# much, since the value attribute is defined
# to have type Any in the typeshed stub.)
var.type = typ
var.is_inferred = True
# TODO: Perhaps return a type object type?
return AnyType()

def visit_typeddict_expr(self, e: TypedDictExpr) -> Type:
# TODO: Perhaps return a type object type?
return AnyType()
Expand Down
19 changes: 19 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,6 +1830,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_typeddict_expr(self)


class EnumCallExpr(Expression):
"""Named tuple expression Enum('name', 'val1 val2 ...')."""

# The class representation of this enumerated type
info = None # type: TypeInfo
# The item names (for debugging)
items = None # type: List[str]
values = None # type: List[Optional[Expression]]

def __init__(self, info: 'TypeInfo', items: List[str],
values: List[Optional[Expression]]) -> None:
self.info = info
self.items = items
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_enum_call_expr(self)


class PromoteExpr(Expression):
"""Ducktype class decorator expression _promote(...)."""

Expand Down
136 changes: 135 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode,
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode,
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr,
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ARG_OPT, nongen_builtins,
collections_type_aliases, get_member_expr_fullname,
)
Expand Down Expand Up @@ -1498,6 +1498,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
self.process_typevar_declaration(s)
self.process_namedtuple_definition(s)
self.process_typeddict_definition(s)
self.process_enum_call(s)

if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and
s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and
Expand Down Expand Up @@ -2327,6 +2328,139 @@ def is_classvar(self, typ: Type) -> bool:
def fail_invalid_classvar(self, context: Context) -> None:
self.fail('ClassVar can only be used for assignments in class body', context)

def process_enum_call(self, s: AssignmentStmt) -> None:
"""Check if s defines an Enum; if yes, store the definition in symbol table."""
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr):
return
lvalue = s.lvalues[0]
name = lvalue.name
enum_call = self.check_enum_call(s.rvalue, name)
if enum_call is None:
return
# Yes, it's a valid Enum definition. Add it to the symbol table.
node = self.lookup(name, s)
if node:
node.kind = GDEF # TODO locally defined Enum
node.node = enum_call

def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]:
"""Check if a call defines an Enum.

Example:

A = enum.Enum('A', 'foo bar')

is equivalent to:

class A(enum.Enum):
foo = 1
bar = 2
"""
if not isinstance(node, CallExpr):
return None
call = node
callee = call.callee
if not isinstance(callee, RefExpr):
return None
fullname = callee.fullname
if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test cases for enum.Flag and enum.IntFlag?

return None
items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1])
if not ok:
# Error. Construct dummy return value.
return self.build_enum_call_typeinfo('Enum', [], fullname)
name = cast(StrExpr, call.args[0]).value
if name != var_name or self.is_func_scope():
# Give it a unique name derived from the line number.
name += '@' + str(call.line)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test anonymous Enum that is available externally. Example:

class A:
    def f(self) -> None:
        E = Enum('E', 'a b')
        self.x = E.a
a = A()
reveal_type(a.x)

info = self.build_enum_call_typeinfo(name, items, fullname)
# Store it as a global just in case it would remain anonymous.
# (Or in the nearest class if there is one.)
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
if self.type:
self.type.names[name] = stnode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Enum defined in a class body.

else:
self.globals[name] = stnode
call.analyzed = EnumCallExpr(info, items, values)
call.analyzed.set_line(call.line, call.column)
return info

def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo:
base = self.named_type_or_none(fullname)
assert base is not None
info = self.basic_new_typeinfo(name, base)
info.is_enum = True
for item in items:
var = Var(item)
var.info = info
var.is_property = True
info.names[item] = SymbolTableNode(MDEF, var)
return info

def parse_enum_call_args(self, call: CallExpr,
class_name: str) -> Tuple[List[str],
List[Optional[Expression]], bool]:
args = call.args
if len(args) < 2:
return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call)
if len(args) > 2:
return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call)
if call.arg_kinds != [ARG_POS, ARG_POS]:
return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test case that triggers this error.

if not isinstance(args[0], (StrExpr, UnicodeExpr)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Python 2 test case that uses both str and unicode literals for everything.

return self.fail_enum_call_arg(
"%s() expects a string literal as the first argument" % class_name, call)
items = []
values = [] # type: List[Optional[Expression]]
if isinstance(args[1], (StrExpr, UnicodeExpr)):
fields = args[1].value
for field in fields.replace(',', ' ').split():
items.append(field)
elif isinstance(args[1], (TupleExpr, ListExpr)):
seq_items = args[1].items
if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items):
items = [cast(StrExpr, seq_item).value for seq_item in seq_items]
elif all(isinstance(seq_item, (TupleExpr, ListExpr))
and len(seq_item.items) == 2
and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr))
for seq_item in seq_items):
for seq_item in seq_items:
assert isinstance(seq_item, (TupleExpr, ListExpr))
name, value = seq_item.items
assert isinstance(name, (StrExpr, UnicodeExpr))
items.append(name.value)
values.append(value)
else:
return self.fail_enum_call_arg(
"%s() with tuple or list expects strings or (name, value) pairs" %
class_name,
call)
elif isinstance(args[1], DictExpr):
for key, value in args[1].items:
if not isinstance(key, (StrExpr, UnicodeExpr)):
return self.fail_enum_call_arg(
"%s() with dict literal requires string literals" % class_name, call)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to also trigger this error in tests.

items.append(key.value)
values.append(value)
else:
# TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
return self.fail_enum_call_arg(
"%s() expects a string, tuple, list or dict literal as the second argument" %
class_name,
call)
if len(items) == 0:
return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call)
if not values:
values = [None] * len(items)
assert len(items) == len(values)
return items, values, True

def fail_enum_call_arg(self, message: str,
context: Context) -> Tuple[List[str],
List[Optional[Expression]], bool]:
self.fail(message, context)
return [], [], False

def visit_decorator(self, dec: Decorator) -> None:
for d in dec.decorators:
d.accept(self)
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str:
o.info.name(),
o.info.tuple_type)

def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str:
return 'EnumCallExpr:{}({}, {})'.format(o.line, o.info.name(), o.items)

def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str:
return 'TypedDictExpr:{}({})'.format(o.line,
o.info.name())
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
'check-newsyntax.test',
'check-underscores.test',
'check-classvar.test',
'check-enum.test',
]


Expand Down
3 changes: 1 addition & 2 deletions mypy/test/testpythoneval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
python_eval_files = ['pythoneval.test',
'python2eval.test']

python_34_eval_files = ['pythoneval-asyncio.test',
'pythoneval-enum.test']
python_34_eval_files = ['pythoneval-asyncio.test']

# Path to Python 3 interpreter
python3_path = sys.executable
Expand Down
6 changes: 5 additions & 1 deletion mypy/treetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
ComparisonExpr, TempNode, StarExpr, Statement, Expression,
YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension,
DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr,
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, OverloadPart
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
OverloadPart, EnumCallExpr,
)
from mypy.types import Type, FunctionLike
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -486,6 +487,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
return NamedTupleExpr(node.info)

def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
return EnumCallExpr(node.info, node.items, node.values)

def visit_typeddict_expr(self, node: TypedDictExpr) -> Node:
return TypedDictExpr(node.info)

Expand Down
7 changes: 7 additions & 0 deletions mypy/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
pass

@abstractmethod
def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
pass

@abstractmethod
def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
pass
Expand Down Expand Up @@ -514,6 +518,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T:
def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
pass

def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T:
pass

def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T:
pass

Expand Down
Loading