4
4
# --------------------------------------------------------------------------
5
5
6
6
import ast
7
- from .values import DebugInfo
7
+
8
+ from onnxscript import debuginfo
8
9
9
10
10
11
def get_loop_var (for_stmt , converter ):
11
12
if not isinstance (for_stmt .target , ast .Name ):
12
- raise ValueError (DebugInfo (for_stmt , converter ).msg (
13
- "For loop target must be a single variable." ))
13
+ raise ValueError (
14
+ debuginfo .DebugInfo (for_stmt , converter ).msg (
15
+ "For loop target must be a single variable."
16
+ )
17
+ )
14
18
return for_stmt .target .id
15
19
16
20
17
21
def used_vars (expr ):
18
- ''' Return set of all variables used in an expression.'''
22
+ """ Return set of all variables used in an expression."""
19
23
if isinstance (expr , ast .Name ):
20
24
return set ([expr .id ])
21
25
if isinstance (expr , ast .Call ):
@@ -31,22 +35,24 @@ def used_vars(expr):
31
35
32
36
33
37
def local_defs (lhs ):
34
- '''Utility function to return set of assigned/defined
35
- variables in the lhs of an assignment statement.'''
38
+ """Utility function to return set of assigned/defined
39
+ variables in the lhs of an assignment statement."""
40
+
36
41
def get_id (e ):
37
42
assert isinstance (e , ast .Name ), "Only simple assignments supported."
38
43
return e .id
39
44
40
- if ( isinstance (lhs , ast .Tuple ) ):
41
- return set ([ get_id (x ) for x in lhs .elts ])
42
- return set ([ get_id (lhs )])
45
+ if isinstance (lhs , ast .Tuple ):
46
+ return { get_id (x ) for x in lhs .elts }
47
+ return { get_id (lhs )}
43
48
44
49
45
50
def defs (stmt ):
46
- '''
51
+ """
47
52
Return the set of all variables that may be defined (assigned to) in an
48
53
execution of input stmt.
49
- '''
54
+ """
55
+
50
56
def block_defs (block ):
51
57
result = set ()
52
58
for s in block :
@@ -66,7 +72,7 @@ def block_defs(block):
66
72
if isinstance (stmt , ast .Break ):
67
73
return set ()
68
74
try :
69
- if stmt .value .func .id == ' print' :
75
+ if stmt .value .func .id == " print" :
70
76
# Any call to print function are ignored.
71
77
return set ()
72
78
except (TypeError , AttributeError ):
@@ -75,11 +81,12 @@ def block_defs(block):
75
81
76
82
77
83
def do_liveness_analysis (fun , converter ):
78
- '''
84
+ """
79
85
Perform liveness analysis of the given function-ast. The results of the
80
86
analysis are stored directly with each statement-ast `s` as attributes `s.live_in`
81
87
and `s.live_out`.
82
- '''
88
+ """
89
+
83
90
def visit (stmt , live_out ):
84
91
stmt .live_out = live_out
85
92
live = do_visit (stmt , live_out )
@@ -124,22 +131,25 @@ def visitBlock(block, live_out):
124
131
# Break statements in the middle of the loop, however, will require
125
132
# a generalization.
126
133
return live_out
127
- if isinstance (stmt , ast .Expr ) and hasattr (stmt , ' value' ):
134
+ if isinstance (stmt , ast .Expr ) and hasattr (stmt , " value" ):
128
135
# docstring
129
- if hasattr (stmt .value , ' value' ) and isinstance (stmt .value .value , str ):
136
+ if hasattr (stmt .value , " value" ) and isinstance (stmt .value .value , str ):
130
137
# python 3.8+
131
138
return live_out
132
- if hasattr (stmt .value , 's' ) and isinstance (stmt .value .s , str ):
139
+ if hasattr (stmt .value , "s" ) and isinstance (stmt .value .s , str ):
133
140
# python 3.7
134
141
return live_out
135
142
try :
136
- if stmt .value .func .id == ' print' :
143
+ if stmt .value .func .id == " print" :
137
144
# Any call to print function are ignored.
138
145
return live_out
139
146
except (TypeError , AttributeError ):
140
147
pass
141
- raise ValueError (DebugInfo (stmt , converter ).msg (
142
- f"Unsupported statement type { type (stmt )!r} ." ))
148
+ raise ValueError (
149
+ debuginfo .DebugInfo (stmt , converter ).msg (
150
+ f"Unsupported statement type { type (stmt )!r} ."
151
+ )
152
+ )
143
153
144
154
assert isinstance (fun , ast .FunctionDef )
145
155
live = set ()
@@ -148,7 +158,7 @@ def visitBlock(block, live_out):
148
158
149
159
150
160
def exposed_uses (stmts , converter ):
151
- '''
161
+ """
152
162
Return the set of variables that are used before being defined by given block.
153
163
In essence, this identifies the "inputs" to a given code-block.
154
164
For example, consider the following code-block:
@@ -163,7 +173,8 @@ def exposed_uses(stmts, converter):
163
173
the block. Even though the value of y is used within the block, it is assigned
164
174
a value before it is used. However, in contrast, the incoming value of x is used
165
175
(in the first statement). Hence x is included in the exposed_uses.
166
- '''
176
+ """
177
+
167
178
def visitBlock (block , live_out ):
168
179
for stmt in reversed (block ):
169
180
live_out = visit (stmt , live_out )
@@ -180,10 +191,13 @@ def visit(stmt, live_out):
180
191
live1 = visitBlock (stmt .body , live_out )
181
192
live2 = visitBlock (stmt .orelse , live_out )
182
193
return (live1 | live2 ) | used_vars (stmt .test )
183
- if (isinstance (stmt , ast .Expr ) and hasattr (stmt , 'value' ) and
184
- isinstance (stmt .value , ast .Call )):
194
+ if (
195
+ isinstance (stmt , ast .Expr )
196
+ and hasattr (stmt , "value" )
197
+ and isinstance (stmt .value , ast .Call )
198
+ ):
185
199
f = stmt .value .func
186
- if f .id == ' print' :
200
+ if f .id == " print" :
187
201
return live_out
188
202
if isinstance (stmt , ast .For ):
189
203
# Analysis assumes loop may execute zero times. Results can be improved
@@ -201,7 +215,10 @@ def visit(stmt, live_out):
201
215
return used_inside_loop | used_in_loop_header | live_out
202
216
if isinstance (stmt , ast .Break ):
203
217
return live_out
204
- raise ValueError (DebugInfo (stmt , converter ).msg (
205
- f"Unsupported statement type { type (stmt )!r} ." ))
218
+ raise ValueError (
219
+ debuginfo .DebugInfo (stmt , converter ).msg (
220
+ f"Unsupported statement type { type (stmt )!r} ."
221
+ )
222
+ )
206
223
207
224
return visitBlock (stmts , set ())
0 commit comments