1
1
"""Helpers for introspecting and wrapping annotations."""
2
2
3
3
import ast
4
+ import builtins
4
5
import enum
5
6
import functools
7
+ import keyword
6
8
import sys
7
9
import types
8
10
@@ -154,8 +156,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
154
156
globals [param_name ] = param
155
157
locals .pop (param_name , None )
156
158
157
- code = self .__forward_code__
158
- value = eval (code , globals = globals , locals = locals )
159
+ arg = self .__forward_arg__
160
+ if arg .isidentifier () and not keyword .iskeyword (arg ):
161
+ if arg in locals :
162
+ value = locals [arg ]
163
+ elif arg in globals :
164
+ value = globals [arg ]
165
+ elif hasattr (builtins , arg ):
166
+ return getattr (builtins , arg )
167
+ else :
168
+ raise NameError (arg )
169
+ else :
170
+ code = self .__forward_code__
171
+ value = eval (code , globals = globals , locals = locals )
159
172
self .__forward_evaluated__ = True
160
173
self .__forward_value__ = value
161
174
return value
@@ -254,7 +267,9 @@ class _Stringifier:
254
267
__slots__ = _SLOTS
255
268
256
269
def __init__ (self , node , globals = None , owner = None , is_class = False , cell = None ):
257
- assert isinstance (node , ast .AST )
270
+ # Either an AST node or a simple str (for the common case where a ForwardRef
271
+ # represent a single name).
272
+ assert isinstance (node , (ast .AST , str ))
258
273
self .__arg__ = None
259
274
self .__forward_evaluated__ = False
260
275
self .__forward_value__ = None
@@ -267,18 +282,26 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
267
282
self .__cell__ = cell
268
283
self .__owner__ = owner
269
284
270
- def __convert (self , other ):
285
+ def __convert_to_ast (self , other ):
271
286
if isinstance (other , _Stringifier ):
287
+ if isinstance (other .__ast_node__ , str ):
288
+ return ast .Name (id = other .__ast_node__ )
272
289
return other .__ast_node__
273
290
elif isinstance (other , slice ):
274
291
return ast .Slice (
275
- lower = self .__convert (other .start ) if other .start is not None else None ,
276
- upper = self .__convert (other .stop ) if other .stop is not None else None ,
277
- step = self .__convert (other .step ) if other .step is not None else None ,
292
+ lower = self .__convert_to_ast (other .start ) if other .start is not None else None ,
293
+ upper = self .__convert_to_ast (other .stop ) if other .stop is not None else None ,
294
+ step = self .__convert_to_ast (other .step ) if other .step is not None else None ,
278
295
)
279
296
else :
280
297
return ast .Constant (value = other )
281
298
299
+ def __get_ast (self ):
300
+ node = self .__ast_node__
301
+ if isinstance (node , str ):
302
+ return ast .Name (id = node )
303
+ return node
304
+
282
305
def __make_new (self , node ):
283
306
return _Stringifier (
284
307
node , self .__globals__ , self .__owner__ , self .__forward_is_class__
@@ -292,38 +315,37 @@ def __hash__(self):
292
315
def __getitem__ (self , other ):
293
316
# Special case, to avoid stringifying references to class-scoped variables
294
317
# as '__classdict__["x"]'.
295
- if (
296
- isinstance (self .__ast_node__ , ast .Name )
297
- and self .__ast_node__ .id == "__classdict__"
298
- ):
318
+ if self .__ast_node__ == "__classdict__" :
299
319
raise KeyError
300
320
if isinstance (other , tuple ):
301
- elts = [self .__convert (elt ) for elt in other ]
321
+ elts = [self .__convert_to_ast (elt ) for elt in other ]
302
322
other = ast .Tuple (elts )
303
323
else :
304
- other = self .__convert (other )
324
+ other = self .__convert_to_ast (other )
305
325
assert isinstance (other , ast .AST ), repr (other )
306
- return self .__make_new (ast .Subscript (self .__ast_node__ , other ))
326
+ return self .__make_new (ast .Subscript (self .__get_ast () , other ))
307
327
308
328
def __getattr__ (self , attr ):
309
- return self .__make_new (ast .Attribute (self .__ast_node__ , attr ))
329
+ return self .__make_new (ast .Attribute (self .__get_ast () , attr ))
310
330
311
331
def __call__ (self , * args , ** kwargs ):
312
332
return self .__make_new (
313
333
ast .Call (
314
- self .__ast_node__ ,
315
- [self .__convert (arg ) for arg in args ],
334
+ self .__get_ast () ,
335
+ [self .__convert_to_ast (arg ) for arg in args ],
316
336
[
317
- ast .keyword (key , self .__convert (value ))
337
+ ast .keyword (key , self .__convert_to_ast (value ))
318
338
for key , value in kwargs .items ()
319
339
],
320
340
)
321
341
)
322
342
323
343
def __iter__ (self ):
324
- yield self .__make_new (ast .Starred (self .__ast_node__ ))
344
+ yield self .__make_new (ast .Starred (self .__get_ast () ))
325
345
326
346
def __repr__ (self ):
347
+ if isinstance (self .__ast_node__ , str ):
348
+ return self .__ast_node__
327
349
return ast .unparse (self .__ast_node__ )
328
350
329
351
def __format__ (self , format_spec ):
@@ -332,7 +354,7 @@ def __format__(self, format_spec):
332
354
def _make_binop (op : ast .AST ):
333
355
def binop (self , other ):
334
356
return self .__make_new (
335
- ast .BinOp (self .__ast_node__ , op , self .__convert (other ))
357
+ ast .BinOp (self .__get_ast () , op , self .__convert_to_ast (other ))
336
358
)
337
359
338
360
return binop
@@ -356,7 +378,7 @@ def binop(self, other):
356
378
def _make_rbinop (op : ast .AST ):
357
379
def rbinop (self , other ):
358
380
return self .__make_new (
359
- ast .BinOp (self .__convert (other ), op , self .__ast_node__ )
381
+ ast .BinOp (self .__convert_to_ast (other ), op , self .__get_ast () )
360
382
)
361
383
362
384
return rbinop
@@ -381,9 +403,9 @@ def _make_compare(op):
381
403
def compare (self , other ):
382
404
return self .__make_new (
383
405
ast .Compare (
384
- left = self .__ast_node__ ,
406
+ left = self .__get_ast () ,
385
407
ops = [op ],
386
- comparators = [self .__convert (other )],
408
+ comparators = [self .__convert_to_ast (other )],
387
409
)
388
410
)
389
411
@@ -400,7 +422,7 @@ def compare(self, other):
400
422
401
423
def _make_unary_op (op ):
402
424
def unary_op (self ):
403
- return self .__make_new (ast .UnaryOp (op , self .__ast_node__ ))
425
+ return self .__make_new (ast .UnaryOp (op , self .__get_ast () ))
404
426
405
427
return unary_op
406
428
@@ -422,7 +444,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False):
422
444
423
445
def __missing__ (self , key ):
424
446
fwdref = _Stringifier (
425
- ast . Name ( id = key ) ,
447
+ key ,
426
448
globals = self .globals ,
427
449
owner = self .owner ,
428
450
is_class = self .is_class ,
@@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
480
502
name = freevars [i ]
481
503
else :
482
504
name = "__cell__"
483
- fwdref = _Stringifier (ast . Name ( id = name ) )
505
+ fwdref = _Stringifier (name )
484
506
new_closure .append (types .CellType (fwdref ))
485
507
closure = tuple (new_closure )
486
508
else :
@@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
532
554
else :
533
555
name = "__cell__"
534
556
fwdref = _Stringifier (
535
- ast . Name ( id = name ) ,
557
+ name ,
536
558
cell = cell ,
537
559
owner = owner ,
538
560
globals = annotate .__globals__ ,
@@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
555
577
result = func (Format .VALUE )
556
578
for obj in globals .stringifiers :
557
579
obj .__class__ = ForwardRef
580
+ if isinstance (obj .__ast_node__ , str ):
581
+ obj .__arg__ = obj .__ast_node__
582
+ obj .__ast_node__ = None
558
583
return result
559
584
elif format == Format .VALUE :
560
585
# Should be impossible because __annotate__ functions must not raise
0 commit comments