
Add __match_args__ to Node types #13243

@dosisod

Description


I am working on a project that makes heavy use of Mypy AST trees/nodes, and it would be nice if there were less boilerplate when pattern matching against nodes (especially in more complex scenarios).

For example, parsing the following expression:

1 + 2

results in an OpExpr node (setup omitted here; see the full example below), which I can then pattern match like so:

match o:
    case OpExpr(
        op="+",
        left=IntExpr(value=1),
        right=IntExpr(value=2),
    ):
        print("found")

    case _:
        print("not found")

This works, but it would be much nicer to be able to write this:

match o:
    case OpExpr("+", IntExpr(1), IntExpr(2)):
        print("found")

    case _:
        print("not found")

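But doing so raises the following exception (this run listed the sub-patterns in a different order, but any positional sub-pattern fails the same way while OpExpr defines no __match_args__):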

Traceback (most recent call last):
  File "[redacted]/main.py", line 39, in <module>
    tree.accept(Visitor())
  File "mypy/nodes.py", line 347, in accept
  File "mypy/traverser.py", line 42, in visit_mypy_file
  File "mypy/nodes.py", line 1128, in accept
  File "mypy/traverser.py", line 89, in visit_expression_stmt
  File "mypy/nodes.py", line 1888, in accept
  File "[redacted]/main.py", line 30, in visit_op_expr
    case OpExpr(IntExpr(1), "+", IntExpr(2)):
TypeError: OpExpr() accepts 0 positional sub-patterns (3 given)

All that would need to be done to fix this (I think) is the following:

 class OpExpr(Expression):
     """Binary operation (other than . or [] or comparison operators,
     which have specific nodes)."""
 
     __slots__ = ('op', 'left', 'right',
                  'method_type', 'right_always', 'right_unreachable')
+
+    __match_args__ = ('op', 'left', 'right')

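The same change would then be made for every node we would feasibly pattern match against (for example, IntExpr would need __match_args__ = ('value',) for IntExpr(1) to work). I am willing to make a PR for this, assuming all that needs to be done is to add a __match_args__ line to a few classes.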
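When adding `__match_args__` to many classes, a typo in one of the names would only surface when someone writes a pattern that hits it. A minimal audit sketch, assuming every matchable attribute is declared in `__slots__` somewhere in the class hierarchy; the helpers `slot_names` and `bad_match_args` and the classes `Node` and `BrokenExpr` are hypothetical, not part of mypy.

```python
# Hypothetical audit helper: check that every name listed in a class's
# __match_args__ is a declared slot somewhere in the class hierarchy.


def slot_names(cls: type) -> set[str]:
    """Collect the __slots__ entries declared on cls and all of its bases."""
    names: set[str] = set()
    for klass in cls.__mro__:
        slots = klass.__dict__.get("__slots__", ())
        if isinstance(slots, str):  # __slots__ = "name" is also legal
            slots = (slots,)
        names.update(slots)
    return names


def bad_match_args(classes: list[type]) -> dict[str, list[str]]:
    """Map each class name to the __match_args__ entries that are not slots."""
    problems: dict[str, list[str]] = {}
    for cls in classes:
        declared = slot_names(cls)
        missing = [n for n in getattr(cls, "__match_args__", ()) if n not in declared]
        if missing:
            problems[cls.__name__] = missing
    return problems


# Hypothetical stand-in hierarchy loosely modelled on mypy.nodes.
class Node:
    __slots__ = ("line", "column")


class Expression(Node):
    __slots__ = ()


class OpExpr(Expression):
    __slots__ = ("op", "left", "right",
                 "method_type", "right_always", "right_unreachable")
    __match_args__ = ("op", "left", "right")


class BrokenExpr(Expression):
    __slots__ = ("op", "left")
    __match_args__ = ("op", "lfet")  # deliberate typo to show what gets reported


print(bad_match_args([OpExpr, BrokenExpr]))
```

A check like this could run as a unit test over all node classes, so a misspelled entry fails loudly instead of silently never matching.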

Full code example
from mypy.build import build
from mypy.main import process_options
from mypy.nodes import IntExpr, OpExpr
from mypy.traverser import TraverserVisitor

files, opt = process_options(["-c", "1 + 2"])
opt.fine_grained_incremental = True

result = build(files, options=opt)


class Visitor(TraverserVisitor):
    def visit_op_expr(self, o: OpExpr) -> None:
        super().visit_op_expr(o)

        match o:
            case OpExpr(
                left=IntExpr(value=1),
                op="+",
                right=IntExpr(value=2),
            ):
                print("found")

            case _:
                print("not found")


tree = result.graph['__main__'].tree

if tree:
    tree.accept(Visitor())
