Skip to content

Commit 825b20e

Browse files
committed
more compat with py38
1 parent 093304a commit 825b20e

File tree

9 files changed

+40
-27
lines changed

9 files changed

+40
-27
lines changed

mlir/extras/ast/canonicalize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from dis import findlinestarts
99
from opcode import opmap
1010
from types import CodeType
11+
from typing import List, Union
1112

13+
import astunparse
1214
from bytecode import ConcreteBytecode
13-
from typing import List, Union
1415

1516
from ..ast.util import get_module_cst, copy_func
1617

@@ -33,14 +34,14 @@ def transform_func(f, *transformer_ctors: type(Transformer)):
3334
module = get_module_cst(f)
3435
context = types.SimpleNamespace()
3536
for transformer_ctor in transformer_ctors:
36-
orig_code = ast.unparse(module)
37+
orig_code = astunparse.unparse(module)
3738
func_node = module.body[0]
3839
replace = transformer_ctor(
3940
context=context, first_lineno=f.__code__.co_firstlineno - 1
4041
)
4142
logger.debug("[transformer] %s", replace.__class__.__name__)
4243
func_node = replace.generic_visit(func_node)
43-
new_code = ast.unparse(func_node)
44+
new_code = astunparse.unparse(func_node)
4445

4546
diff = list(
4647
difflib.unified_diff(

mlir/extras/dialects/ext/func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _is_decl(self):
163163
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
164164
elif sys.version_info.minor == 11:
165165
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
166-
elif sys.version_info.minor == 10:
166+
elif sys.version_info.minor in {8, 9, 10}:
167167
return self.body_builder.__code__.co_code == b"d\x00S\x00"
168168
else:
169169
raise NotImplementedError(f"{sys.version_info.minor} not supported.")

mlir/extras/dialects/ext/scf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ class InsertEmptyYield(StrictTransformer):
455455
def visit_If(self, updated_node: ast.If) -> ast.If:
456456
updated_node = self.generic_visit(updated_node)
457457

458-
new_yield = ast.Expr(ast.Yield())
458+
new_yield = ast.Expr(ast.Yield(value=None))
459459
if not is_yield(updated_node.body[-1]):
460460
updated_node.body = append_hidden_node(
461461
updated_node.body, deepcopy(new_yield)
@@ -469,7 +469,7 @@ def visit_If(self, updated_node: ast.If) -> ast.If:
469469

470470
def visit_For(self, updated_node: ast.For) -> ast.For:
471471
updated_node = self.generic_visit(updated_node)
472-
new_yield = ast.Expr(ast.Yield())
472+
new_yield = ast.Expr(ast.Yield(value=None))
473473
if not is_yield(updated_node.body[-1]):
474474
updated_node.body = append_hidden_node(updated_node.body, new_yield)
475475
return updated_node
@@ -536,7 +536,7 @@ def visit_While(self, updated_node: ast.While) -> List[ast.AST]:
536536
next.__name__,
537537
[
538538
ast.Name(f"w_{updated_node.lineno}", ctx=ast.Load()),
539-
ast.Constant(False),
539+
ast.Constant(False, kind="bool"),
540540
],
541541
)
542542
next_ = ast.fix_missing_locations(ast.copy_location(next_, updated_node))
@@ -603,7 +603,8 @@ def visit_If(self, updated_node: ast.If) -> Union[ast.With, List[ast.With]]:
603603
if updated_node.orelse:
604604
if_op_name = ast.Name(f"__if_op__{updated_node.lineno}", ctx=ast.Load())
605605
withitem = ast.withitem(
606-
context_expr=ast_call(else_ctx_manager.__name__, args=[if_op_name])
606+
context_expr=ast_call(else_ctx_manager.__name__, args=[if_op_name]),
607+
optional_vars=None,
607608
)
608609
else_with = ast.With(items=[withitem])
609610
if is_elif:

mlir/extras/testing/testing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def filecheck(correct: str, module):
2424
# try to find using which
2525
if not filecheck_path.exists():
2626
filecheck_path = shutil.which(filecheck_name)
27-
assert Path(filecheck_path).exists() is not None, "couldn't find FileCheck"
27+
assert (
28+
filecheck_path is not None and Path(filecheck_path).exists() is not None
29+
), "couldn't find FileCheck"
2830

2931
correct = "\n".join(filter(None, correct.splitlines()))
3032
correct = dedent(correct)

mlir/extras/util.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
TypeID = object
4040

4141

42+
def is_relative_to(self, other):
43+
return other == self or other in self.parents
44+
45+
4246
def get_user_code_loc(user_base: Optional[Path] = None):
4347
from .. import extras
4448

@@ -52,20 +56,18 @@ def get_user_code_loc(user_base: Optional[Path] = None):
5256
user_base = Path(prev_frame.f_code.co_filename)
5357

5458
while prev_frame.f_back and (
55-
Path(prev_frame.f_code.co_filename).is_relative_to(mlir_extras_root_path)
56-
or Path(prev_frame.f_code.co_filename).is_relative_to(sys.prefix)
57-
or Path(prev_frame.f_code.co_filename).is_relative_to(user_base)
59+
is_relative_to(Path(prev_frame.f_code.co_filename), mlir_extras_root_path)
60+
or is_relative_to(Path(prev_frame.f_code.co_filename), sys.prefix)
61+
or is_relative_to(Path(prev_frame.f_code.co_filename), user_base)
5862
):
5963
prev_frame = prev_frame.f_back
6064
frame_info = inspect.getframeinfo(prev_frame)
6165
if sys.version_info.minor >= 11:
6266
return Location.file(
6367
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
6468
)
65-
elif sys.version_info.minor == 10:
66-
return Location.file(frame_info.filename, frame_info.lineno, col=0)
6769
else:
68-
raise NotImplementedError(f"{sys.version_info.minor} not supported.")
70+
return Location.file(frame_info.filename, frame_info.lineno, col=0)
6971

7072

7173
@contextlib.contextmanager

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ PyYAML
33
black
44
bytecode
55
inflection
6-
numpy
6+
numpy
7+
astunparse

tests/test_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def demo_fun1():
4646
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
4747
elif sys.version_info.minor == 11:
4848
assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00"
49-
elif sys.version_info.minor == 10:
49+
elif sys.version_info.minor in {8, 9, 10}:
5050
assert demo_fun1.__code__.co_code == b"d\x00S\x00"
5151
else:
5252
raise NotImplementedError(f"{sys.version_info.minor} not supported.")

tests/test_gpu.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from textwrap import dedent
23

34
import pytest
@@ -138,6 +139,7 @@ def mat_product_kernel(
138139
filecheck(correct, ctx.module)
139140

140141

142+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
141143
def test_class_call(ctx: MLIRContext):
142144
scale = 1
143145
M, N, K = 4 * scale, 16 * scale, 8 * scale
@@ -163,7 +165,10 @@ def mat_product_kernel(
163165
b = alloc((N, K), T.f32())
164166
c = alloc((M, K), T.f32())
165167

166-
MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c)
168+
# this is to avoid python 3.8 parser
169+
eval(
170+
"MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c)"
171+
)
167172

168173
correct = dedent(
169174
"""\
@@ -196,6 +201,7 @@ def mat_product_kernel(
196201
filecheck(correct, ctx.module)
197202

198203

204+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
199205
def test_class_call_from_func(ctx: MLIRContext):
200206
scale = 1
201207
M, N, K = 4 * scale, 16 * scale, 8 * scale
@@ -227,8 +233,9 @@ def main():
227233
b = alloc((N, K), T.f32())
228234
c = alloc((M, K), T.f32())
229235

230-
MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](
231-
a, b, c
236+
MyClass1
237+
eval(
238+
"MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c)"
232239
)
233240

234241
ctx.module.operation.verify()
@@ -267,6 +274,7 @@ def main():
267274
filecheck(correct, ctx.module)
268275

269276

277+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
270278
def test_async_object(ctx: MLIRContext):
271279
scale = 1
272280
M, N, K = 4 * scale, 16 * scale, 8 * scale
@@ -300,12 +308,9 @@ def main():
300308

301309
w = wait()
302310
stream = mlir_zero(llvm_ptr_t())
303-
MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](
304-
a,
305-
b,
306-
c,
307-
async_dependencies=[w],
308-
stream=stream,
311+
MyClass1
312+
eval(
313+
"MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c, async_dependencies=[w], stream=stream)"
309314
)
310315

311316
correct = dedent(

tests/test_transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ast
22
import sys
33
from textwrap import dedent
4+
from typing import Tuple
45

56
import astpretty
67
import pytest
@@ -22,7 +23,7 @@
2223
pytest.mark.usefixtures("ctx")
2324

2425

25-
def _fields(n: ast.AST, show_offsets: bool = True) -> tuple[str, ...]:
26+
def _fields(n: ast.AST, show_offsets: bool = True) -> Tuple[str, ...]:
2627
strip = {"type_ignores", "decorator_list", "type_comment", "ctx", "kind"}
2728
fields = tuple(f for f in n._fields if f not in strip)
2829
attributes = ("lineno",) if "lineno" in n._attributes else ()

0 commit comments

Comments
 (0)