Skip to content

Commit 31bec6f

Browse files
authored
bpo-43897: AST validation for pattern matching nodes (GH24771)
1 parent 53b9458 commit 31bec6f

File tree

2 files changed

+265
-32
lines changed

2 files changed

+265
-32
lines changed

Lib/test/test_ast.py

Lines changed: 142 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -696,7 +696,7 @@ def test_constant_as_name(self):
696696
for constant in "True", "False", "None":
697697
expr = ast.Expression(ast.Name(constant, ast.Load()))
698698
ast.fix_missing_locations(expr)
699-
with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
699+
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
700700
compile(expr, "<test>", "eval")
701701

702702
def test_precedence_enum(self):
@@ -1507,6 +1507,147 @@ def test_stdlib_validates(self):
15071507
mod = ast.parse(source, fn)
15081508
compile(mod, fn, "exec")
15091509

1510+
constant_1 = ast.Constant(1)
1511+
pattern_1 = ast.MatchValue(constant_1)
1512+
1513+
constant_x = ast.Constant('x')
1514+
pattern_x = ast.MatchValue(constant_x)
1515+
1516+
constant_true = ast.Constant(True)
1517+
pattern_true = ast.MatchSingleton(True)
1518+
1519+
name_carter = ast.Name('carter', ast.Load())
1520+
1521+
_MATCH_PATTERNS = [
1522+
ast.MatchValue(
1523+
ast.Attribute(
1524+
ast.Attribute(
1525+
ast.Name('x', ast.Store()),
1526+
'y', ast.Load()
1527+
),
1528+
'z', ast.Load()
1529+
)
1530+
),
1531+
ast.MatchValue(
1532+
ast.Attribute(
1533+
ast.Attribute(
1534+
ast.Name('x', ast.Load()),
1535+
'y', ast.Store()
1536+
),
1537+
'z', ast.Load()
1538+
)
1539+
),
1540+
ast.MatchValue(
1541+
ast.Constant(...)
1542+
),
1543+
ast.MatchValue(
1544+
ast.Constant(True)
1545+
),
1546+
ast.MatchValue(
1547+
ast.Constant((1,2,3))
1548+
),
1549+
ast.MatchSingleton('string'),
1550+
ast.MatchSequence([
1551+
ast.MatchSingleton('string')
1552+
]),
1553+
ast.MatchSequence(
1554+
[
1555+
ast.MatchSequence(
1556+
[
1557+
ast.MatchSingleton('string')
1558+
]
1559+
)
1560+
]
1561+
),
1562+
ast.MatchMapping(
1563+
[constant_1, constant_true],
1564+
[pattern_x]
1565+
),
1566+
ast.MatchMapping(
1567+
[constant_true, constant_1],
1568+
[pattern_x, pattern_1],
1569+
rest='True'
1570+
),
1571+
ast.MatchMapping(
1572+
[constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
1573+
[pattern_x, pattern_1],
1574+
rest='legit'
1575+
),
1576+
ast.MatchClass(
1577+
ast.Attribute(
1578+
ast.Attribute(
1579+
constant_x,
1580+
'y', ast.Load()),
1581+
'z', ast.Load()),
1582+
patterns=[], kwd_attrs=[], kwd_patterns=[]
1583+
),
1584+
ast.MatchClass(
1585+
name_carter,
1586+
patterns=[],
1587+
kwd_attrs=['True'],
1588+
kwd_patterns=[pattern_1]
1589+
),
1590+
ast.MatchClass(
1591+
name_carter,
1592+
patterns=[],
1593+
kwd_attrs=[],
1594+
kwd_patterns=[pattern_1]
1595+
),
1596+
ast.MatchClass(
1597+
name_carter,
1598+
patterns=[ast.MatchSingleton('string')],
1599+
kwd_attrs=[],
1600+
kwd_patterns=[]
1601+
),
1602+
ast.MatchClass(
1603+
name_carter,
1604+
patterns=[ast.MatchStar()],
1605+
kwd_attrs=[],
1606+
kwd_patterns=[]
1607+
),
1608+
ast.MatchClass(
1609+
name_carter,
1610+
patterns=[],
1611+
kwd_attrs=[],
1612+
kwd_patterns=[ast.MatchStar()]
1613+
),
1614+
ast.MatchSequence(
1615+
[
1616+
ast.MatchStar("True")
1617+
]
1618+
),
1619+
ast.MatchAs(
1620+
name='False'
1621+
),
1622+
ast.MatchOr(
1623+
[]
1624+
),
1625+
ast.MatchOr(
1626+
[pattern_1]
1627+
),
1628+
ast.MatchOr(
1629+
[pattern_1, pattern_x, ast.MatchSingleton('xxx')]
1630+
)
1631+
]
1632+
1633+
def test_match_validation_pattern(self):
1634+
name_x = ast.Name('x', ast.Load())
1635+
for pattern in self._MATCH_PATTERNS:
1636+
with self.subTest(ast.dump(pattern, indent=4)):
1637+
node = ast.Match(
1638+
subject=name_x,
1639+
cases = [
1640+
ast.match_case(
1641+
pattern=pattern,
1642+
body = [ast.Pass()]
1643+
)
1644+
]
1645+
)
1646+
node = ast.fix_missing_locations(node)
1647+
module = ast.Module([node], [])
1648+
with self.assertRaises(ValueError):
1649+
compile(module, "<test>", "exec")
1650+
15101651

15111652
class ConstantTests(unittest.TestCase):
15121653
"""Tests on the ast.Constant node type."""

Python/ast.c

Lines changed: 123 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@ struct validator {
1515
};
1616

1717
static int validate_stmts(struct validator *, asdl_stmt_seq *);
18-
static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int);
18+
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19+
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
1920
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
2021
static int validate_stmt(struct validator *, stmt_ty);
2122
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
@@ -33,7 +34,7 @@ validate_name(PyObject *name)
3334
};
3435
for (int i = 0; forbidden[i] != NULL; i++) {
3536
if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
36-
PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
37+
PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
3738
return 0;
3839
}
3940
}
@@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
448449
switch (exp->kind)
449450
{
450451
case Constant_kind:
452+
/* Ellipsis and immutable sequences are not allowed.
453+
For True, False and None, MatchSingleton() should
454+
be used */
455+
if (!validate_expr(state, exp, Load)) {
456+
return 0;
457+
}
458+
PyObject *literal = exp->v.Constant.value;
459+
if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
460+
PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
461+
PyUnicode_CheckExact(literal)) {
462+
return 1;
463+
}
464+
PyErr_SetString(PyExc_ValueError,
465+
"unexpected constant inside of a literal pattern");
466+
return 0;
451467
case Attribute_kind:
452468
// Constants and attribute lookups are always permitted
453469
return 1;
@@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
465481
return 1;
466482
}
467483
break;
484+
case JoinedStr_kind:
485+
// Handled in the later stages
486+
return 1;
468487
default:
469488
break;
470489
}
471-
PyErr_SetString(PyExc_SyntaxError,
472-
"patterns may only match literals and attribute lookups");
490+
PyErr_SetString(PyExc_ValueError,
491+
"patterns may only match literals and attribute lookups");
473492
return 0;
474493
}
475494

@@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
489508
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
490509
break;
491510
case MatchSingleton_kind:
492-
// TODO: Check constant is specifically None, True, or False
493-
ret = validate_constant(state, p->v.MatchSingleton.value);
511+
ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
512+
if (!ret) {
513+
PyErr_SetString(PyExc_ValueError,
514+
"MatchSingleton can only contain True, False and None");
515+
}
494516
break;
495517
case MatchSequence_kind:
496-
// TODO: Validate all subpatterns
497-
// return validate_patterns(state, p->v.MatchSequence.patterns);
498-
ret = 1;
518+
ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
499519
break;
500520
case MatchMapping_kind:
501-
// TODO: check "rest" target name is valid
502521
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
503522
PyErr_SetString(PyExc_ValueError,
504523
"MatchMapping doesn't have the same number of keys as patterns");
505-
return 0;
524+
ret = 0;
525+
break;
506526
}
507-
// null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
508-
// TODO: replace with more restrictive expression validator, as per MatchValue above
509-
if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
510-
return 0;
527+
528+
if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
529+
ret = 0;
530+
break;
511531
}
512-
// TODO: Validate all subpatterns
513-
// ret = validate_patterns(state, p->v.MatchMapping.patterns);
514-
ret = 1;
532+
533+
asdl_expr_seq *keys = p->v.MatchMapping.keys;
534+
for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
535+
expr_ty key = asdl_seq_GET(keys, i);
536+
if (key->kind == Constant_kind) {
537+
PyObject *literal = key->v.Constant.value;
538+
if (literal == Py_None || PyBool_Check(literal)) {
539+
/* validate_pattern_match_value will ensure the key
540+
doesn't contain True, False and None but it is
541+
syntactically valid, so we will pass those on in
542+
a special case. */
543+
continue;
544+
}
545+
}
546+
if (!validate_pattern_match_value(state, key)) {
547+
ret = 0;
548+
break;
549+
}
550+
}
551+
552+
ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
515553
break;
516554
case MatchClass_kind:
517555
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
518556
PyErr_SetString(PyExc_ValueError,
519557
"MatchClass doesn't have the same number of keyword attributes as patterns");
520-
return 0;
558+
ret = 0;
559+
break;
521560
}
522-
// TODO: Restrict cls lookup to being a name or attribute
523561
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
524-
return 0;
562+
ret = 0;
563+
break;
525564
}
526-
// TODO: Validate all subpatterns
527-
// return validate_patterns(state, p->v.MatchClass.patterns) &&
528-
// validate_patterns(state, p->v.MatchClass.kwd_patterns);
529-
ret = 1;
565+
566+
expr_ty cls = p->v.MatchClass.cls;
567+
while (1) {
568+
if (cls->kind == Name_kind) {
569+
break;
570+
}
571+
else if (cls->kind == Attribute_kind) {
572+
cls = cls->v.Attribute.value;
573+
continue;
574+
}
575+
else {
576+
PyErr_SetString(PyExc_ValueError,
577+
"MatchClass cls field can only contain Name or Attribute nodes.");
578+
state->recursion_depth--;
579+
return 0;
580+
}
581+
}
582+
583+
for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
584+
PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
585+
if (!validate_name(identifier)) {
586+
state->recursion_depth--;
587+
return 0;
588+
}
589+
}
590+
591+
if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
592+
ret = 0;
593+
break;
594+
}
595+
596+
ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
530597
break;
531598
case MatchStar_kind:
532-
// TODO: check target name is valid
533-
ret = 1;
599+
ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
534600
break;
535601
case MatchAs_kind:
536-
// TODO: check target name is valid
602+
if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
603+
ret = 0;
604+
break;
605+
}
537606
if (p->v.MatchAs.pattern == NULL) {
538607
ret = 1;
539608
}
@@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
547616
}
548617
break;
549618
case MatchOr_kind:
550-
// TODO: Validate all subpatterns
551-
// return validate_patterns(state, p->v.MatchOr.patterns);
552-
ret = 1;
619+
if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
620+
PyErr_SetString(PyExc_ValueError,
621+
"MatchOr requires at least 2 patterns");
622+
ret = 0;
623+
break;
624+
}
625+
ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
553626
break;
554627
// No default case, so the compiler will emit a warning if new pattern
555628
// kinds are added without being handled here
@@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
815888
return 1;
816889
}
817890

891+
static int
892+
validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
893+
{
894+
Py_ssize_t i;
895+
for (i = 0; i < asdl_seq_LEN(patterns); i++) {
896+
pattern_ty pattern = asdl_seq_GET(patterns, i);
897+
if (pattern->kind == MatchStar_kind && !star_ok) {
898+
PyErr_SetString(PyExc_ValueError,
899+
"Can't use MatchStar within this sequence of patterns");
900+
return 0;
901+
}
902+
if (!validate_pattern(state, pattern)) {
903+
return 0;
904+
}
905+
}
906+
return 1;
907+
}
908+
909+
818910
/* See comments in symtable.c. */
819911
#define COMPILER_STACK_FRAME_SCALE 3
820912

0 commit comments

Comments
 (0)