Skip to content

Commit d80dfd3

Browse files
committed
add py312 tests
1 parent e97d5c0 commit d80dfd3

File tree

12 files changed

+342
-159
lines changed

12 files changed

+342
-159
lines changed

.github/workflows/test.yml

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ on:
1616

1717
jobs:
1818

19-
test-all:
19+
test-mlir-bindings:
2020

2121
runs-on: ${{ matrix.os }}
2222

2323
strategy:
2424
fail-fast: false
2525
matrix:
2626
os: [ ubuntu-22.04, macos-11, windows-2022 ]
27-
py_version: [ "3.10", "3.11" ]
27+
py_version: [ "3.10", "3.11", "3.12" ]
2828

2929
steps:
3030
- name: Checkout
@@ -34,14 +34,53 @@ jobs:
3434
uses: actions/setup-python@v4
3535
with:
3636
python-version: ${{ matrix.py_version }}
37+
allow-prereleases: true
3738

3839
- name: Install and configure
3940
shell: bash
4041
run: |
4142
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
4243
mlir-python-utils-generate-all-upstream-trampolines
43-
44-
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[jax] -v
44+
45+
- name: Test
46+
shell: bash
47+
run: |
48+
if [ ${{ matrix.os }} == 'windows-2022' ]; then
49+
pytest -s --ignore-glob=*test_other_hosts* tests
50+
else
51+
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
52+
fi
53+
54+
- name: Test mwe
55+
shell: bash
56+
run: |
57+
python examples/mwe.py
58+
59+
test-other-host-bindings:
60+
61+
runs-on: ${{ matrix.os }}
62+
63+
strategy:
64+
fail-fast: false
65+
matrix:
66+
os: [ ubuntu-22.04, macos-11, windows-2022 ]
67+
py_version: [ "3.10", "3.11" ]
68+
69+
steps:
70+
- name: Checkout
71+
uses: actions/checkout@v2
72+
73+
- name: Setup Python
74+
uses: actions/setup-python@v4
75+
with:
76+
python-version: ${{ matrix.py_version }}
77+
allow-prereleases: true
78+
79+
- name: Install and configure
80+
shell: bash
81+
run: |
82+
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
83+
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
4584
jaxlib-mlir-python-utils-generate-all-upstream-trampolines
4685
4786
pip install aie -f https://github.com/Xilinx/mlir-aie/releases/expanded_assets/latest-wheels --no-index
@@ -53,16 +92,11 @@ jobs:
5392
shell: bash
5493
run: |
5594
if [ ${{ matrix.os }} == 'windows-2022' ]; then
56-
pytest -s tests
95+
pytest -s tests/test_other_hosts.py
5796
else
58-
pytest --capture=tee-sys tests
97+
pytest --capture=tee-sys tests/test_other_hosts.py
5998
fi
6099
61-
- name: Test mwe
62-
shell: bash
63-
run: |
64-
python examples/mwe.py
65-
66100
test-jupyter:
67101

68102
runs-on: ${{ matrix.os }}
@@ -81,6 +115,7 @@ jobs:
81115
uses: actions/setup-python@v4
82116
with:
83117
python-version: ${{ matrix.py_version }}
118+
allow-prereleases: true
84119

85120
- name: Run notebook
86121
shell: bash
@@ -98,7 +133,7 @@ jobs:
98133
fail-fast: false
99134
matrix:
100135
os: [ ubuntu-22.04 ]
101-
py_version: [ "3.10", "3.11" ]
136+
py_version: [ "3.10", "3.11", "3.12" ]
102137

103138
steps:
104139
- name: Checkout
@@ -120,16 +155,17 @@ jobs:
120155
bash miniconda.sh -b -u -p /root/miniconda3
121156
eval "$(/root/miniconda3/bin/conda shell.bash hook)"
122157
conda init
123-
conda install -q -y python=${{ matrix.py_version }}
124158
125159
run: |
126160
127161
eval "$(/root/miniconda3/bin/conda shell.bash hook)"
162+
conda create -n env -q -y -c conda-forge/label/python_rc python=${{ matrix.py_version }}
163+
conda activate env
128164
129165
cd /workspace
130166
131-
pip install .[test,mlir] -f https://makslevental.github.io/wheels
167+
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
132168
mlir-python-utils-generate-all-upstream-trampolines
133169
134-
pytest --capture=tee-sys --ignore-glob=*test_smoke* tests
170+
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
135171
python examples/mwe.py

mlir/utils/_configuration/generate_trampolines.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def generate_op_trampoline(op_class):
126126
args=args,
127127
body=body,
128128
decorator_list=decorator_list,
129+
type_params=[],
129130
)
130131
ast.fix_missing_locations(n)
131132
return n
@@ -323,6 +324,7 @@ def generate_linalg(mod_path):
323324
),
324325
body=body,
325326
decorator_list=[],
327+
type_params=[],
326328
)
327329
ast.fix_missing_locations(n)
328330
functions.append(n)

mlir/utils/_configuration/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def add_file_to_sources_txt_file(file_path: Path):
2424

2525
assert file_path.exists(), f"file being added doesn't exist at {file_path}"
2626
relative_file_path = Path(package) / file_path.relative_to(package_root_path)
27-
if dist._read_files_egginfo() is not None:
27+
if hasattr(dist, "_read_files_egginfo") and dist._read_files_egginfo() is not None:
2828
with open(dist._path / "SOURCES.txt", "a") as sources_file:
2929
sources_file.write(f"\n{relative_file_path}")
3030
if dist._read_files_distinfo():

mlir/utils/dialects/ext/func.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import inspect
2+
import sys
23
from typing import Union, Optional
34

5+
from ...meta import make_maybe_no_args_decorator, maybe_cast
6+
from ...util import get_result_or_results, get_user_code_loc
47
from ....dialects.func import FuncOp, ReturnOp, CallOp
58
from ....ir import (
69
InsertionPoint,
@@ -12,9 +15,6 @@
1215
Value,
1316
)
1417

15-
from ...util import get_result_or_results, get_user_code_loc, is_311
16-
from ...meta import make_maybe_no_args_decorator, maybe_cast
17-
1818

1919
def call(
2020
callee_or_results: Union[FuncOp, list[Type]],
@@ -145,10 +145,14 @@ def __init__(
145145

146146
def _is_decl(self):
147147
# magic constant found from looking at the code for an empty fn
148-
if is_311():
148+
if sys.version_info.minor == 12:
149+
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
150+
elif sys.version_info.minor == 11:
149151
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
150-
else:
152+
elif sys.version_info.minor == 10:
151153
return self.body_builder.__code__.co_code == b"d\x00S\x00"
154+
else:
155+
raise NotImplementedError(f"{sys.version_info.minor} not supported.")
152156

153157
def __str__(self):
154158
return str(f"{self.__class__} {self.__dict__}")

mlir/utils/dialects/ext/scf.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55
from typing import Optional, Sequence, Union
66

77
from bytecode import ConcreteBytecode, ConcreteInstr
8+
9+
from ... import types as T
10+
from ...ast.canonicalize import (
11+
StrictTransformer,
12+
Canonicalizer,
13+
BytecodePatcher,
14+
OpCode,
15+
)
16+
from ...ast.util import ast_call, set_lineno
17+
from ...dialects.ext.arith import constant, index_cast
18+
from ...dialects.ext.gpu import get_device_mapping_array_attr
19+
from ...dialects.scf import yield_ as yield__, reduce_return, condition
20+
from ...meta import region_adder, region_op, maybe_cast
21+
from ...util import get_result_or_results, get_user_code_loc
822
from ....dialects._ods_common import get_op_results_or_values, get_default_loc_context
923
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
1024
from ....dialects.scf import (
@@ -28,20 +42,6 @@
2842
Attribute,
2943
)
3044

31-
from ... import types as T
32-
from ...ast.canonicalize import (
33-
StrictTransformer,
34-
Canonicalizer,
35-
BytecodePatcher,
36-
OpCode,
37-
)
38-
from ...ast.util import ast_call, set_lineno
39-
from ...dialects.ext.arith import constant, index_cast
40-
from ...dialects.ext.gpu import get_device_mapping_array_attr
41-
from ...dialects.scf import yield_ as yield__, reduce_return, condition
42-
from ...meta import region_adder, region_op, maybe_cast
43-
from ...util import get_result_or_results, get_user_code_loc, is_311
44-
4545
logger = logging.getLogger(__name__)
4646

4747

@@ -633,25 +633,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]:
633633

634634
class RemoveJumpsAndInsertGlobals(BytecodePatcher):
635635
def patch_bytecode(self, code: ConcreteBytecode, f):
636-
early_returns = []
637-
for i, c in enumerate(code):
638-
c: ConcreteInstr
639-
if c.opcode == int(OpCode.RETURN_VALUE):
640-
early_returns.append(i)
641-
642-
if c.opcode in {
643-
# this is the first test condition jump from python <= 3.10
644-
# "POP_JUMP_IF_FALSE",
645-
# this is the test condition jump from python >= 3.11
646-
int(OpCode.POP_JUMP_FORWARD_IF_FALSE)
647-
if is_311()
648-
else int(OpCode.POP_JUMP_IF_FALSE),
649-
}:
650-
code[i] = ConcreteInstr(
651-
str(OpCode.POP_TOP), lineno=c.lineno, location=c.location
652-
)
653-
654-
# TODO(max): this is bad
636+
# TODO(max): this is bad and should be in the closure rather than as a global
655637
f.__globals__[yield_.__name__] = yield_
656638
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
657639
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager

mlir/utils/util.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ def get_result_or_results(
3434
)
3535

3636

37-
def is_311():
38-
return sys.version_info.minor > 10
39-
40-
4137
def get_user_code_loc(user_base: Optional[Path] = None):
4238
from .. import utils
4339

@@ -54,12 +50,14 @@ def get_user_code_loc(user_base: Optional[Path] = None):
5450
):
5551
prev_frame = prev_frame.f_back
5652
frame_info = inspect.getframeinfo(prev_frame)
57-
if is_311():
53+
if sys.version_info.minor >= 11:
5854
return Location.file(
5955
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
6056
)
61-
else:
57+
elif sys.version_info.minor == 10:
6258
return Location.file(frame_info.filename, frame_info.lineno, col=0)
59+
else:
60+
raise NotImplementedError(f"{sys.version_info.minor} not supported.")
6361

6462

6563
@contextlib.contextmanager

tests/test_func.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import inspect
2+
import sys
23
from textwrap import dedent
34

45
import pytest
56

7+
import mlir.utils.types as T
68
from mlir.utils.dialects.ext.arith import constant
79
from mlir.utils.dialects.ext.func import func
810

911
# noinspection PyUnresolvedReferences
1012
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
11-
import mlir.utils.types as T
12-
from mlir.utils.util import is_311
1313

1414
# needed since the fix isn't defined here nor conftest.py
1515
pytest.mark.usefixtures("ctx")
@@ -41,10 +41,14 @@ def test_declare_byte_rep(ctx: MLIRContext):
4141
def demo_fun1():
4242
...
4343

44-
if is_311():
44+
if sys.version_info.minor == 12:
45+
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
46+
elif sys.version_info.minor == 11:
4547
assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00"
46-
else:
48+
elif sys.version_info.minor == 10:
4749
assert demo_fun1.__code__.co_code == b"d\x00S\x00"
50+
else:
51+
raise NotImplementedError(f"{sys.version_info.minor} not supported.")
4852

4953

5054
def test_declare(ctx: MLIRContext):

tests/test_location_tracking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from os import sep
23
from pathlib import Path
34
from textwrap import dedent
@@ -13,7 +14,6 @@
1314

1415
# noinspection PyUnresolvedReferences
1516
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
16-
from mlir.utils.util import is_311
1717

1818
# needed since the fix isn't defined here nor conftest.py
1919
pytest.mark.usefixtures("ctx")
@@ -27,7 +27,7 @@ def get_asm(operation):
2727
)
2828

2929

30-
@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
30+
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
3131
def test_if_replace_yield_5(ctx: MLIRContext):
3232
@canonicalize(using=canonicalizer)
3333
def iffoo():
@@ -72,7 +72,7 @@ def iffoo():
7272
filecheck(correct, asm)
7373

7474

75-
@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
75+
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
7676
def test_block_args(ctx: MLIRContext):
7777
one = constant(1, T.index)
7878
two = constant(2, T.index)

0 commit comments

Comments
 (0)