Skip to content

Commit 17a544b

Browse files
gh-119180: Avoid going through AST and eval() when possible in annotationlib (#124337)
Often, ForwardRefs represent a single simple name. In that case, we can avoid going through the overhead of creating AST nodes and code objects and calling eval(): we can simply look up the name directly in the relevant namespaces. Co-authored-by: Victor Stinner <[email protected]>
1 parent 9d8f2d8 commit 17a544b

File tree

2 files changed

+88
-28
lines changed

2 files changed

+88
-28
lines changed

Lib/annotationlib.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Helpers for introspecting and wrapping annotations."""
22

33
import ast
4+
import builtins
45
import enum
56
import functools
7+
import keyword
68
import sys
79
import types
810

@@ -154,8 +156,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
154156
globals[param_name] = param
155157
locals.pop(param_name, None)
156158

157-
code = self.__forward_code__
158-
value = eval(code, globals=globals, locals=locals)
159+
arg = self.__forward_arg__
160+
if arg.isidentifier() and not keyword.iskeyword(arg):
161+
if arg in locals:
162+
value = locals[arg]
163+
elif arg in globals:
164+
value = globals[arg]
165+
elif hasattr(builtins, arg):
166+
return getattr(builtins, arg)
167+
else:
168+
raise NameError(arg)
169+
else:
170+
code = self.__forward_code__
171+
value = eval(code, globals=globals, locals=locals)
159172
self.__forward_evaluated__ = True
160173
self.__forward_value__ = value
161174
return value
@@ -254,7 +267,9 @@ class _Stringifier:
254267
__slots__ = _SLOTS
255268

256269
def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
257-
assert isinstance(node, ast.AST)
270+
# Either an AST node or a simple str (for the common case where a ForwardRef
271+
# represent a single name).
272+
assert isinstance(node, (ast.AST, str))
258273
self.__arg__ = None
259274
self.__forward_evaluated__ = False
260275
self.__forward_value__ = None
@@ -267,18 +282,26 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
267282
self.__cell__ = cell
268283
self.__owner__ = owner
269284

270-
def __convert(self, other):
285+
def __convert_to_ast(self, other):
271286
if isinstance(other, _Stringifier):
287+
if isinstance(other.__ast_node__, str):
288+
return ast.Name(id=other.__ast_node__)
272289
return other.__ast_node__
273290
elif isinstance(other, slice):
274291
return ast.Slice(
275-
lower=self.__convert(other.start) if other.start is not None else None,
276-
upper=self.__convert(other.stop) if other.stop is not None else None,
277-
step=self.__convert(other.step) if other.step is not None else None,
292+
lower=self.__convert_to_ast(other.start) if other.start is not None else None,
293+
upper=self.__convert_to_ast(other.stop) if other.stop is not None else None,
294+
step=self.__convert_to_ast(other.step) if other.step is not None else None,
278295
)
279296
else:
280297
return ast.Constant(value=other)
281298

299+
def __get_ast(self):
300+
node = self.__ast_node__
301+
if isinstance(node, str):
302+
return ast.Name(id=node)
303+
return node
304+
282305
def __make_new(self, node):
283306
return _Stringifier(
284307
node, self.__globals__, self.__owner__, self.__forward_is_class__
@@ -292,38 +315,37 @@ def __hash__(self):
292315
def __getitem__(self, other):
293316
# Special case, to avoid stringifying references to class-scoped variables
294317
# as '__classdict__["x"]'.
295-
if (
296-
isinstance(self.__ast_node__, ast.Name)
297-
and self.__ast_node__.id == "__classdict__"
298-
):
318+
if self.__ast_node__ == "__classdict__":
299319
raise KeyError
300320
if isinstance(other, tuple):
301-
elts = [self.__convert(elt) for elt in other]
321+
elts = [self.__convert_to_ast(elt) for elt in other]
302322
other = ast.Tuple(elts)
303323
else:
304-
other = self.__convert(other)
324+
other = self.__convert_to_ast(other)
305325
assert isinstance(other, ast.AST), repr(other)
306-
return self.__make_new(ast.Subscript(self.__ast_node__, other))
326+
return self.__make_new(ast.Subscript(self.__get_ast(), other))
307327

308328
def __getattr__(self, attr):
309-
return self.__make_new(ast.Attribute(self.__ast_node__, attr))
329+
return self.__make_new(ast.Attribute(self.__get_ast(), attr))
310330

311331
def __call__(self, *args, **kwargs):
312332
return self.__make_new(
313333
ast.Call(
314-
self.__ast_node__,
315-
[self.__convert(arg) for arg in args],
334+
self.__get_ast(),
335+
[self.__convert_to_ast(arg) for arg in args],
316336
[
317-
ast.keyword(key, self.__convert(value))
337+
ast.keyword(key, self.__convert_to_ast(value))
318338
for key, value in kwargs.items()
319339
],
320340
)
321341
)
322342

323343
def __iter__(self):
324-
yield self.__make_new(ast.Starred(self.__ast_node__))
344+
yield self.__make_new(ast.Starred(self.__get_ast()))
325345

326346
def __repr__(self):
347+
if isinstance(self.__ast_node__, str):
348+
return self.__ast_node__
327349
return ast.unparse(self.__ast_node__)
328350

329351
def __format__(self, format_spec):
@@ -332,7 +354,7 @@ def __format__(self, format_spec):
332354
def _make_binop(op: ast.AST):
333355
def binop(self, other):
334356
return self.__make_new(
335-
ast.BinOp(self.__ast_node__, op, self.__convert(other))
357+
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
336358
)
337359

338360
return binop
@@ -356,7 +378,7 @@ def binop(self, other):
356378
def _make_rbinop(op: ast.AST):
357379
def rbinop(self, other):
358380
return self.__make_new(
359-
ast.BinOp(self.__convert(other), op, self.__ast_node__)
381+
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
360382
)
361383

362384
return rbinop
@@ -381,9 +403,9 @@ def _make_compare(op):
381403
def compare(self, other):
382404
return self.__make_new(
383405
ast.Compare(
384-
left=self.__ast_node__,
406+
left=self.__get_ast(),
385407
ops=[op],
386-
comparators=[self.__convert(other)],
408+
comparators=[self.__convert_to_ast(other)],
387409
)
388410
)
389411

@@ -400,7 +422,7 @@ def compare(self, other):
400422

401423
def _make_unary_op(op):
402424
def unary_op(self):
403-
return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
425+
return self.__make_new(ast.UnaryOp(op, self.__get_ast()))
404426

405427
return unary_op
406428

@@ -422,7 +444,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False):
422444

423445
def __missing__(self, key):
424446
fwdref = _Stringifier(
425-
ast.Name(id=key),
447+
key,
426448
globals=self.globals,
427449
owner=self.owner,
428450
is_class=self.is_class,
@@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
480502
name = freevars[i]
481503
else:
482504
name = "__cell__"
483-
fwdref = _Stringifier(ast.Name(id=name))
505+
fwdref = _Stringifier(name)
484506
new_closure.append(types.CellType(fwdref))
485507
closure = tuple(new_closure)
486508
else:
@@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
532554
else:
533555
name = "__cell__"
534556
fwdref = _Stringifier(
535-
ast.Name(id=name),
557+
name,
536558
cell=cell,
537559
owner=owner,
538560
globals=annotate.__globals__,
@@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
555577
result = func(Format.VALUE)
556578
for obj in globals.stringifiers:
557579
obj.__class__ = ForwardRef
580+
if isinstance(obj.__ast_node__, str):
581+
obj.__arg__ = obj.__ast_node__
582+
obj.__ast_node__ = None
558583
return result
559584
elif format == Format.VALUE:
560585
# Should be impossible because __annotate__ functions must not raise

Lib/test/test_annotationlib.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for the annotations module."""
22

33
import annotationlib
4+
import builtins
45
import collections
56
import functools
67
import itertools
@@ -280,7 +281,14 @@ class Gen[T]:
280281

281282
def test_fwdref_with_module(self):
282283
self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format)
283-
self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter)
284+
self.assertIs(
285+
ForwardRef("Counter", module="collections").evaluate(),
286+
collections.Counter
287+
)
288+
self.assertEqual(
289+
ForwardRef("Counter[int]", module="collections").evaluate(),
290+
collections.Counter[int],
291+
)
284292

285293
with self.assertRaises(NameError):
286294
# If globals are passed explicitly, we don't look at the module dict
@@ -305,6 +313,33 @@ def test_fwdref_value_is_cached(self):
305313
self.assertIs(fr.evaluate(globals={"hello": str}), str)
306314
self.assertIs(fr.evaluate(), str)
307315

316+
def test_fwdref_with_owner(self):
317+
self.assertEqual(
318+
ForwardRef("Counter[int]", owner=collections).evaluate(),
319+
collections.Counter[int],
320+
)
321+
322+
def test_name_lookup_without_eval(self):
323+
# test the codepath where we look up simple names directly in the
324+
# namespaces without going through eval()
325+
self.assertIs(ForwardRef("int").evaluate(), int)
326+
self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str)
327+
self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float)
328+
self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str)
329+
with support.swap_attr(builtins, "int", dict):
330+
self.assertIs(ForwardRef("int").evaluate(), dict)
331+
332+
with self.assertRaises(NameError):
333+
ForwardRef("doesntexist").evaluate()
334+
335+
def test_fwdref_invalid_syntax(self):
336+
fr = ForwardRef("if")
337+
with self.assertRaises(SyntaxError):
338+
fr.evaluate()
339+
fr = ForwardRef("1+")
340+
with self.assertRaises(SyntaxError):
341+
fr.evaluate()
342+
308343

309344
class TestGetAnnotations(unittest.TestCase):
310345
def test_builtin_type(self):

0 commit comments

Comments
 (0)