Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ on:

jobs:

test-all:
test-mlir-bindings:

runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ ubuntu-22.04, macos-11, windows-2022 ]
py_version: [ "3.10", "3.11" ]
py_version: [ "3.10", "3.11", "3.12" ]

steps:
- name: Checkout
Expand All @@ -34,14 +34,53 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Install and configure
shell: bash
run: |
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines

HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[jax] -v

- name: Test
shell: bash
run: |
if [ ${{ matrix.os }} == 'windows-2022' ]; then
pytest -s --ignore-glob=*test_other_hosts* tests
else
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
fi

- name: Test mwe
shell: bash
run: |
python examples/mwe.py

test-other-host-bindings:

runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ ubuntu-22.04, macos-11, windows-2022 ]
py_version: [ "3.10", "3.11" ]

steps:
- name: Checkout
uses: actions/checkout@v2

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Install and configure
shell: bash
run: |
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
jaxlib-mlir-python-utils-generate-all-upstream-trampolines

pip install aie -f https://github.com/Xilinx/mlir-aie/releases/expanded_assets/latest-wheels --no-index
Expand All @@ -53,16 +92,11 @@ jobs:
shell: bash
run: |
if [ ${{ matrix.os }} == 'windows-2022' ]; then
pytest -s tests
pytest -s tests/test_other_hosts.py
else
pytest --capture=tee-sys tests
pytest --capture=tee-sys tests/test_other_hosts.py
fi

- name: Test mwe
shell: bash
run: |
python examples/mwe.py

test-jupyter:

runs-on: ${{ matrix.os }}
Expand All @@ -81,6 +115,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Run notebook
shell: bash
Expand All @@ -98,7 +133,7 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-22.04 ]
py_version: [ "3.10", "3.11" ]
py_version: [ "3.10", "3.11", "3.12" ]

steps:
- name: Checkout
Expand All @@ -120,16 +155,17 @@ jobs:
bash miniconda.sh -b -u -p /root/miniconda3
eval "$(/root/miniconda3/bin/conda shell.bash hook)"
conda init
conda install -q -y python=${{ matrix.py_version }}

run: |

eval "$(/root/miniconda3/bin/conda shell.bash hook)"
conda create -n env -q -y -c conda-forge/label/python_rc python=${{ matrix.py_version }}
conda activate env

cd /workspace

pip install .[test,mlir] -f https://makslevental.github.io/wheels
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines

pytest --capture=tee-sys --ignore-glob=*test_smoke* tests
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
python examples/mwe.py
2 changes: 2 additions & 0 deletions mlir/utils/_configuration/generate_trampolines.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def generate_op_trampoline(op_class):
args=args,
body=body,
decorator_list=decorator_list,
type_params=[],
)
ast.fix_missing_locations(n)
return n
Expand Down Expand Up @@ -323,6 +324,7 @@ def generate_linalg(mod_path):
),
body=body,
decorator_list=[],
type_params=[],
)
ast.fix_missing_locations(n)
functions.append(n)
Expand Down
2 changes: 1 addition & 1 deletion mlir/utils/_configuration/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def add_file_to_sources_txt_file(file_path: Path):

assert file_path.exists(), f"file being added doesn't exist at {file_path}"
relative_file_path = Path(package) / file_path.relative_to(package_root_path)
if dist._read_files_egginfo() is not None:
if hasattr(dist, "_read_files_egginfo") and dist._read_files_egginfo() is not None:
with open(dist._path / "SOURCES.txt", "a") as sources_file:
sources_file.write(f"\n{relative_file_path}")
if dist._read_files_distinfo():
Expand Down
14 changes: 9 additions & 5 deletions mlir/utils/dialects/ext/func.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import inspect
import sys
from typing import Union, Optional

from ...meta import make_maybe_no_args_decorator, maybe_cast
from ...util import get_result_or_results, get_user_code_loc
from ....dialects.func import FuncOp, ReturnOp, CallOp
from ....ir import (
InsertionPoint,
Expand All @@ -12,9 +15,6 @@
Value,
)

from ...util import get_result_or_results, get_user_code_loc, is_311
from ...meta import make_maybe_no_args_decorator, maybe_cast


def call(
callee_or_results: Union[FuncOp, list[Type]],
Expand Down Expand Up @@ -145,10 +145,14 @@ def __init__(

def _is_decl(self):
# magic constant found from looking at the code for an empty fn
if is_311():
if sys.version_info.minor == 12:
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
else:
elif sys.version_info.minor == 10:
return self.body_builder.__code__.co_code == b"d\x00S\x00"
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")

def __str__(self):
return str(f"{self.__class__} {self.__dict__}")
Expand Down
48 changes: 15 additions & 33 deletions mlir/utils/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@
from typing import Optional, Sequence, Union

from bytecode import ConcreteBytecode, ConcreteInstr

from ... import types as T
from ...ast.canonicalize import (
StrictTransformer,
Canonicalizer,
BytecodePatcher,
OpCode,
)
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ...dialects.scf import yield_ as yield__, reduce_return, condition
from ...meta import region_adder, region_op, maybe_cast
from ...util import get_result_or_results, get_user_code_loc
from ....dialects._ods_common import get_op_results_or_values, get_default_loc_context
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
from ....dialects.scf import (
Expand All @@ -28,20 +42,6 @@
Attribute,
)

from ... import types as T
from ...ast.canonicalize import (
StrictTransformer,
Canonicalizer,
BytecodePatcher,
OpCode,
)
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ...dialects.scf import yield_ as yield__, reduce_return, condition
from ...meta import region_adder, region_op, maybe_cast
from ...util import get_result_or_results, get_user_code_loc, is_311

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -633,25 +633,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]:

class RemoveJumpsAndInsertGlobals(BytecodePatcher):
def patch_bytecode(self, code: ConcreteBytecode, f):
early_returns = []
for i, c in enumerate(code):
c: ConcreteInstr
if c.opcode == int(OpCode.RETURN_VALUE):
early_returns.append(i)

if c.opcode in {
# this is the first test condition jump from python <= 3.10
# "POP_JUMP_IF_FALSE",
# this is the test condition jump from python >= 3.11
int(OpCode.POP_JUMP_FORWARD_IF_FALSE)
if is_311()
else int(OpCode.POP_JUMP_IF_FALSE),
}:
code[i] = ConcreteInstr(
str(OpCode.POP_TOP), lineno=c.lineno, location=c.location
)

# TODO(max): this is bad
# TODO(max): this is bad and should be in the closure rather than as a global
f.__globals__[yield_.__name__] = yield_
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager
Expand Down
10 changes: 4 additions & 6 deletions mlir/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def get_result_or_results(
)


def is_311():
return sys.version_info.minor > 10


def get_user_code_loc(user_base: Optional[Path] = None):
from .. import utils

Expand All @@ -54,12 +50,14 @@ def get_user_code_loc(user_base: Optional[Path] = None):
):
prev_frame = prev_frame.f_back
frame_info = inspect.getframeinfo(prev_frame)
if is_311():
if sys.version_info.minor >= 11:
return Location.file(
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
)
else:
elif sys.version_info.minor == 10:
return Location.file(frame_info.filename, frame_info.lineno, col=0)
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")


@contextlib.contextmanager
Expand Down
12 changes: 8 additions & 4 deletions tests/test_func.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import inspect
import sys
from textwrap import dedent

import pytest

import mlir.utils.types as T
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.func import func

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
import mlir.utils.types as T
from mlir.utils.util import is_311

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")
Expand Down Expand Up @@ -41,10 +41,14 @@ def test_declare_byte_rep(ctx: MLIRContext):
def demo_fun1():
...

if is_311():
if sys.version_info.minor == 12:
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00"
else:
elif sys.version_info.minor == 10:
assert demo_fun1.__code__.co_code == b"d\x00S\x00"
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")


def test_declare(ctx: MLIRContext):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_location_tracking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from os import sep
from pathlib import Path
from textwrap import dedent
Expand All @@ -13,7 +14,6 @@

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
from mlir.utils.util import is_311

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")
Expand All @@ -27,7 +27,7 @@ def get_asm(operation):
)


@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
def test_if_replace_yield_5(ctx: MLIRContext):
@canonicalize(using=canonicalizer)
def iffoo():
Expand Down Expand Up @@ -72,7 +72,7 @@ def iffoo():
filecheck(correct, asm)


@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
def test_block_args(ctx: MLIRContext):
one = constant(1, T.index)
two = constant(2, T.index)
Expand Down
Loading