I am working on a project that makes heavy use of mypy AST nodes, and it would be nice if there were less boilerplate when pattern matching against nodes (especially in more complex scenarios).
For example, parsing the expression `1 + 2` results in an `OpExpr` node (setup omitted; see the full example below), which I can then pattern match like so:
```python
match o:
    case OpExpr(
        op="+",
        left=IntExpr(value=1),
        right=IntExpr(value=2),
    ):
        print("found")
    case _:
        print("not found")
```
This works, but what would be really nice is being able to do this:
```python
match o:
    case OpExpr("+", IntExpr(1), IntExpr(2)):
        print("found")
    case _:
        print("not found")
```
But doing so causes the following exception:
```
Traceback (most recent call last):
  File "[redacted]/main.py", line 39, in <module>
    tree.accept(Visitor())
  File "mypy/nodes.py", line 347, in accept
  File "mypy/traverser.py", line 42, in visit_mypy_file
  File "mypy/nodes.py", line 1128, in accept
  File "mypy/traverser.py", line 89, in visit_expression_stmt
  File "mypy/nodes.py", line 1888, in accept
  File "[redacted]/main.py", line 30, in visit_op_expr
    case OpExpr(IntExpr(1), "+", IntExpr(2)):
TypeError: OpExpr() accepts 0 positional sub-patterns (3 given)
```
All that would need to be done to fix this (I think) is the following:
```diff
 class OpExpr(Expression):
     """Binary operation (other than . or [] or comparison operators,
     which have specific nodes)."""

     __slots__ = ('op', 'left', 'right',
                  'method_type', 'right_always', 'right_unreachable')
+
+    __match_args__ = ('op', 'left', 'right')
```
Then do the same for every node we would feasibly pattern match against. I am willing to make a PR for this, assuming all that needs to be done is adding a `__match_args__` line to a few classes.
Full code example
```python
from mypy.build import build
from mypy.main import process_options
from mypy.nodes import IntExpr, OpExpr
from mypy.traverser import TraverserVisitor

files, opt = process_options(["-c", "1 + 2"])
opt.fine_grained_incremental = True
result = build(files, options=opt)


class Visitor(TraverserVisitor):
    def visit_op_expr(self, o: OpExpr) -> None:
        super().visit_op_expr(o)

        match o:
            case OpExpr(
                left=IntExpr(value=1),
                op="+",
                right=IntExpr(value=2),
            ):
                print("found")
            case _:
                print("not found")


tree = result.graph['__main__'].tree
if tree:
    tree.accept(Visitor())
```