Skip to content

Commit 992f554

Browse files
titaiwangmsUbuntu
and
Ubuntu
authored
[Lint] Apply lint (#126)
Should be merged after #120 fixes #125 1. add `auto-formatter` **lint.sh** (feel free to add on others) 2. flake8, black, and isort formatting is clean now 3. pylint and mypy is a lot of work to address - need everyone address these step by step (also take them out from style.sh, and make style_optional.sh CI test for them.) 4. relative path addressed (import module instead of class and function) 5. We should enable Lint/Enforce as mandatory 6. Move debuginfo stand alone as it caused circular import. NOTE: I hope this one can be quickly merged, as there are PRs being merged constantly which means endless merge conflict for this one. Co-authored-by: Ubuntu <titaiwang@titaiwanglinuxcpudev.y3zdd0j2xrqelnmezcgpqgmnte.jx.internal.cloudapp.net>
1 parent a48de87 commit 992f554

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2148
-1421
lines changed

README.md

-1
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,3 @@ optional check:
155155
```
156156

157157
NOTE: mypy and pylint needs to be manually address
158-

onnxscript/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# --------------------------------------------------------------------------
55

66
import sys
7-
from .main import script, export_onnx_lib, OnnxFunction
7+
88
from .backend.onnx_export import export2python as proto2python
9+
from .main import export_onnx_lib, script
10+
from .values import OnnxFunction
911

1012
if sys.version_info[0:2] >= (3, 8):
1113
import importlib.metadata as importlib_metadata

onnxscript/__main__.py

+48-27
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
import os
77
from typing import Optional
8-
import onnx
9-
import onnx.helper as helper
8+
109
import click
11-
from onnxscript.converter import Converter
12-
from onnxscript.backend.onnx_export import export2python
10+
import onnx
11+
from onnx import helper
12+
13+
from onnxscript import converter
14+
from onnxscript.backend import onnx_export
1315

1416

1517
@click.group()
@@ -18,18 +20,18 @@ def cli():
1820

1921

2022
def convert_file(script):
21-
converter = Converter()
22-
return converter.convert_file(script)
23+
convert = converter.Converter()
24+
return convert.convert_file(script)
2325

2426

2527
def to_single_model_proto(model, input_py_file: str, output_onnx_file: Optional[str] = None):
26-
if (not output_onnx_file):
27-
prefix, ext = os.path.splitext(input_py_file)
28+
if not output_onnx_file:
29+
prefix, _ = os.path.splitext(input_py_file)
2830
output_onnx_file = prefix + ".onnx"
2931

3032
fnlist = convert_file(input_py_file)
3133

32-
if (not fnlist):
34+
if not fnlist:
3335
print("No functions in input.")
3436
return
3537

@@ -47,8 +49,9 @@ def to_single_model_proto(model, input_py_file: str, output_onnx_file: Optional[
4749
model = onnx.helper.make_model(
4850
graph,
4951
functions=[f.to_function_proto() for f in fnlist],
50-
producer_name='p2o',
51-
opset_imports=[onnx.helper.make_opsetid("", 15)])
52+
producer_name="p2o",
53+
opset_imports=[onnx.helper.make_opsetid("", 15)],
54+
)
5255

5356
# TODO: add options for user to specify whether to check generated model
5457
# model = onnx.shape_inference.infer_shapes(model)
@@ -71,34 +74,52 @@ def print_ir_function(f):
7174

7275

7376
@cli.command()
74-
@click.option('--fmt', type=click.Choice(['text', 'model', 'lib'], case_sensitive=False),
75-
help="Translate input to a single ModelProto ('model'), "
76-
"into a LibProto ('lib'), "
77-
"or into text 'text').")
78-
@click.option('name', '--name', envvar='PATHS', multiple=True, type=click.Path(),
79-
help="File or files to convert.")
80-
def translate(fmt="text", name=None):
77+
@click.option(
78+
"--fmt",
79+
type=click.Choice(["text", "model", "lib"], case_sensitive=False),
80+
help="Translate input to a single ModelProto ('model'), "
81+
"into a LibProto ('lib'), "
82+
"or into text 'text').",
83+
)
84+
@click.option(
85+
"name",
86+
"--name",
87+
envvar="PATHS",
88+
multiple=True,
89+
type=click.Path(),
90+
help="File or files to convert.",
91+
)
92+
def translate(fmt="text", names=None):
8193
"""Translate a file or many files into a ModelProto, a LibProto or text."""
8294
if fmt == "text":
83-
for name in name:
95+
for name in names:
8496
to_text(name)
8597
else:
86-
for name in name:
98+
for name in names:
8799
to_single_model_proto(fmt == "model", name)
88100

89101

90102
@cli.command()
91-
@click.option('name', '--name', envvar='PATHS', multiple=False, type=click.Path(),
92-
help="filename to convert")
93-
@click.option("--op", is_flag=True, default=False,
94-
help="converts a numerical operator into op.Add (False) or keep it (True)")
95-
@click.option("--rename", is_flag=True, default=False,
96-
help="to use shorter variable name")
103+
@click.option(
104+
"name",
105+
"--name",
106+
envvar="PATHS",
107+
multiple=False,
108+
type=click.Path(),
109+
help="filename to convert",
110+
)
111+
@click.option(
112+
"--op",
113+
is_flag=True,
114+
default=False,
115+
help="converts a numerical operator into op.Add (False) or keep it (True)",
116+
)
117+
@click.option("--rename", is_flag=True, default=False, help="to use shorter variable name")
97118
def onnx2script(name, op=False, rename=False):
98119
"""Exports an onnx graph to a script in following onnx-script syntax.
99120
The result is printed on the standard output.
100121
"""
101-
code = export2python(name, use_operators=op, rename=rename)
122+
code = onnx_export.export2python(name, use_operators=op, rename=rename)
102123
print(code)
103124

104125

onnxscript/analysis.py

+44-27
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,22 @@
44
# --------------------------------------------------------------------------
55

66
import ast
7-
from .values import DebugInfo
7+
8+
from onnxscript import debuginfo
89

910

1011
def get_loop_var(for_stmt, converter):
1112
if not isinstance(for_stmt.target, ast.Name):
12-
raise ValueError(DebugInfo(for_stmt, converter).msg(
13-
"For loop target must be a single variable."))
13+
raise ValueError(
14+
debuginfo.DebugInfo(for_stmt, converter).msg(
15+
"For loop target must be a single variable."
16+
)
17+
)
1418
return for_stmt.target.id
1519

1620

1721
def used_vars(expr):
18-
''' Return set of all variables used in an expression.'''
22+
"""Return set of all variables used in an expression."""
1923
if isinstance(expr, ast.Name):
2024
return set([expr.id])
2125
if isinstance(expr, ast.Call):
@@ -31,22 +35,24 @@ def used_vars(expr):
3135

3236

3337
def local_defs(lhs):
34-
'''Utility function to return set of assigned/defined
35-
variables in the lhs of an assignment statement.'''
38+
"""Utility function to return set of assigned/defined
39+
variables in the lhs of an assignment statement."""
40+
3641
def get_id(e):
3742
assert isinstance(e, ast.Name), "Only simple assignments supported."
3843
return e.id
3944

40-
if (isinstance(lhs, ast.Tuple)):
41-
return set([get_id(x) for x in lhs.elts])
42-
return set([get_id(lhs)])
45+
if isinstance(lhs, ast.Tuple):
46+
return {get_id(x) for x in lhs.elts}
47+
return {get_id(lhs)}
4348

4449

4550
def defs(stmt):
46-
'''
51+
"""
4752
Return the set of all variables that may be defined (assigned to) in an
4853
execution of input stmt.
49-
'''
54+
"""
55+
5056
def block_defs(block):
5157
result = set()
5258
for s in block:
@@ -66,7 +72,7 @@ def block_defs(block):
6672
if isinstance(stmt, ast.Break):
6773
return set()
6874
try:
69-
if stmt.value.func.id == 'print':
75+
if stmt.value.func.id == "print":
7076
# Any call to print function are ignored.
7177
return set()
7278
except (TypeError, AttributeError):
@@ -75,11 +81,12 @@ def block_defs(block):
7581

7682

7783
def do_liveness_analysis(fun, converter):
78-
'''
84+
"""
7985
Perform liveness analysis of the given function-ast. The results of the
8086
analysis are stored directly with each statement-ast `s` as attributes `s.live_in`
8187
and `s.live_out`.
82-
'''
88+
"""
89+
8390
def visit(stmt, live_out):
8491
stmt.live_out = live_out
8592
live = do_visit(stmt, live_out)
@@ -124,22 +131,25 @@ def visitBlock(block, live_out):
124131
# Break statements in the middle of the loop, however, will require
125132
# a generalization.
126133
return live_out
127-
if isinstance(stmt, ast.Expr) and hasattr(stmt, 'value'):
134+
if isinstance(stmt, ast.Expr) and hasattr(stmt, "value"):
128135
# docstring
129-
if hasattr(stmt.value, 'value') and isinstance(stmt.value.value, str):
136+
if hasattr(stmt.value, "value") and isinstance(stmt.value.value, str):
130137
# python 3.8+
131138
return live_out
132-
if hasattr(stmt.value, 's') and isinstance(stmt.value.s, str):
139+
if hasattr(stmt.value, "s") and isinstance(stmt.value.s, str):
133140
# python 3.7
134141
return live_out
135142
try:
136-
if stmt.value.func.id == 'print':
143+
if stmt.value.func.id == "print":
137144
# Any call to print function are ignored.
138145
return live_out
139146
except (TypeError, AttributeError):
140147
pass
141-
raise ValueError(DebugInfo(stmt, converter).msg(
142-
f"Unsupported statement type {type(stmt)!r}."))
148+
raise ValueError(
149+
debuginfo.DebugInfo(stmt, converter).msg(
150+
f"Unsupported statement type {type(stmt)!r}."
151+
)
152+
)
143153

144154
assert isinstance(fun, ast.FunctionDef)
145155
live = set()
@@ -148,7 +158,7 @@ def visitBlock(block, live_out):
148158

149159

150160
def exposed_uses(stmts, converter):
151-
'''
161+
"""
152162
Return the set of variables that are used before being defined by given block.
153163
In essence, this identifies the "inputs" to a given code-block.
154164
For example, consider the following code-block:
@@ -163,7 +173,8 @@ def exposed_uses(stmts, converter):
163173
the block. Even though the value of y is used within the block, it is assigned
164174
a value before it is used. However, in contrast, the incoming value of x is used
165175
(in the first statement). Hence x is included in the exposed_uses.
166-
'''
176+
"""
177+
167178
def visitBlock(block, live_out):
168179
for stmt in reversed(block):
169180
live_out = visit(stmt, live_out)
@@ -180,10 +191,13 @@ def visit(stmt, live_out):
180191
live1 = visitBlock(stmt.body, live_out)
181192
live2 = visitBlock(stmt.orelse, live_out)
182193
return (live1 | live2) | used_vars(stmt.test)
183-
if (isinstance(stmt, ast.Expr) and hasattr(stmt, 'value') and
184-
isinstance(stmt.value, ast.Call)):
194+
if (
195+
isinstance(stmt, ast.Expr)
196+
and hasattr(stmt, "value")
197+
and isinstance(stmt.value, ast.Call)
198+
):
185199
f = stmt.value.func
186-
if f.id == 'print':
200+
if f.id == "print":
187201
return live_out
188202
if isinstance(stmt, ast.For):
189203
# Analysis assumes loop may execute zero times. Results can be improved
@@ -201,7 +215,10 @@ def visit(stmt, live_out):
201215
return used_inside_loop | used_in_loop_header | live_out
202216
if isinstance(stmt, ast.Break):
203217
return live_out
204-
raise ValueError(DebugInfo(stmt, converter).msg(
205-
f"Unsupported statement type {type(stmt)!r}."))
218+
raise ValueError(
219+
debuginfo.DebugInfo(stmt, converter).msg(
220+
f"Unsupported statement type {type(stmt)!r}."
221+
)
222+
)
206223

207224
return visitBlock(stmts, set())

0 commit comments

Comments
 (0)