|
| 1 | +from typing import List, Optional |
| 2 | + |
| 3 | +from mypy.nodes import ( |
| 4 | + ARG_OPT, ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase, |
| 5 | + FuncDef, PassStmt, RefExpr, SymbolTableNode, Var |
| 6 | +) |
| 7 | +from mypy.plugin import ClassDefContext |
| 8 | +from mypy.semanal import set_callable_name |
| 9 | +from mypy.types import CallableType, Overloaded, Type, TypeVarDef |
| 10 | +from mypy.typevars import fill_typevars |
| 11 | + |
| 12 | + |
| 13 | +def _get_decorator_bool_argument( |
| 14 | + ctx: ClassDefContext, |
| 15 | + name: str, |
| 16 | + default: bool, |
| 17 | +) -> bool: |
| 18 | + """Return the bool argument for the decorator. |
| 19 | +
|
| 20 | + This handles both @decorator(...) and @decorator. |
| 21 | + """ |
| 22 | + if isinstance(ctx.reason, CallExpr): |
| 23 | + return _get_bool_argument(ctx, ctx.reason, name, default) |
| 24 | + else: |
| 25 | + return default |
| 26 | + |
| 27 | + |
| 28 | +def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, |
| 29 | + name: str, default: bool) -> bool: |
| 30 | + """Return the boolean value for an argument to a call or the |
| 31 | + default if it's not found. |
| 32 | + """ |
| 33 | + attr_value = _get_argument(expr, name) |
| 34 | + if attr_value: |
| 35 | + ret = ctx.api.parse_bool(attr_value) |
| 36 | + if ret is None: |
| 37 | + ctx.api.fail('"{}" argument must be True or False.'.format(name), expr) |
| 38 | + return default |
| 39 | + return ret |
| 40 | + return default |
| 41 | + |
| 42 | + |
| 43 | +def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: |
| 44 | + """Return the expression for the specific argument.""" |
| 45 | + # To do this we use the CallableType of the callee to find the FormalArgument, |
| 46 | + # then walk the actual CallExpr looking for the appropriate argument. |
| 47 | + # |
| 48 | + # Note: I'm not hard-coding the index so that in the future we can support other |
| 49 | + # attrib and class makers. |
| 50 | + callee_type = None |
| 51 | + if (isinstance(call.callee, RefExpr) |
| 52 | + and isinstance(call.callee.node, (Var, FuncBase)) |
| 53 | + and call.callee.node.type): |
| 54 | + callee_node_type = call.callee.node.type |
| 55 | + if isinstance(callee_node_type, Overloaded): |
| 56 | + # We take the last overload. |
| 57 | + callee_type = callee_node_type.items()[-1] |
| 58 | + elif isinstance(callee_node_type, CallableType): |
| 59 | + callee_type = callee_node_type |
| 60 | + |
| 61 | + if not callee_type: |
| 62 | + return None |
| 63 | + |
| 64 | + argument = callee_type.argument_by_name(name) |
| 65 | + if not argument: |
| 66 | + return None |
| 67 | + assert argument.name |
| 68 | + |
| 69 | + for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): |
| 70 | + if argument.pos is not None and not attr_name and i == argument.pos: |
| 71 | + return attr_value |
| 72 | + if attr_name == argument.name: |
| 73 | + return attr_value |
| 74 | + return None |
| 75 | + |
| 76 | + |
| 77 | +def _add_method( |
| 78 | + ctx: ClassDefContext, |
| 79 | + name: str, |
| 80 | + args: List[Argument], |
| 81 | + return_type: Type, |
| 82 | + self_type: Optional[Type] = None, |
| 83 | + tvar_def: Optional[TypeVarDef] = None, |
| 84 | +) -> None: |
| 85 | + """Adds a new method to a class. |
| 86 | + """ |
| 87 | + info = ctx.cls.info |
| 88 | + self_type = self_type or fill_typevars(info) |
| 89 | + function_type = ctx.api.named_type('__builtins__.function') |
| 90 | + |
| 91 | + args = [Argument(Var('self'), self_type, None, ARG_POS)] + args |
| 92 | + arg_types, arg_names, arg_kinds = [], [], [] |
| 93 | + for arg in args: |
| 94 | + assert arg.type_annotation, 'All arguments must be fully typed.' |
| 95 | + arg_types.append(arg.type_annotation) |
| 96 | + arg_names.append(arg.variable.name()) |
| 97 | + arg_kinds.append(arg.kind) |
| 98 | + |
| 99 | + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) |
| 100 | + if tvar_def: |
| 101 | + signature.variables = [tvar_def] |
| 102 | + |
| 103 | + func = FuncDef(name, args, Block([PassStmt()])) |
| 104 | + func.info = info |
| 105 | + func.type = set_callable_name(signature, func) |
| 106 | + func._fullname = info.fullname() + '.' + name |
| 107 | + func.line = info.line |
| 108 | + |
| 109 | + info.names[name] = SymbolTableNode(MDEF, func) |
| 110 | + info.defn.defs.body.append(func) |
0 commit comments