Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mlir/extras/dialects/ext/memref.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from functools import cached_property
from functools import cached_property, reduce
from typing import Tuple, Sequence, Optional, Union

from ....ir import Type, Value, MemRefType, ShapedType, MLIRError
Expand Down Expand Up @@ -163,6 +163,11 @@ def has_rank(self) -> bool:
def shape(self) -> Tuple[int, ...]:
return tuple(self._shaped_type.shape)

@cached_property
def n_elements(self) -> int:
assert self.has_static_shape()
return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)

@cached_property
def dtype(self) -> Type:
return self._shaped_type.element_type
Expand Down
7 changes: 6 additions & 1 deletion mlir/extras/dialects/ext/tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from dataclasses import dataclass
from functools import cached_property
from functools import cached_property, reduce
from typing import Union, Tuple, Sequence, Optional, Any

import numpy as np
Expand Down Expand Up @@ -132,6 +132,11 @@ def has_rank(self) -> bool:
def shape(self) -> Tuple[int, ...]:
return tuple(self._shaped_type.shape)

@cached_property
def n_elements(self) -> int:
assert self.has_static_shape()
return reduce(lambda acc, v: acc * v, self._shaped_type.shape, 1)

@cached_property
def dtype(self) -> Type:
return self._shaped_type.element_type
Expand Down
27 changes: 16 additions & 11 deletions mlir/extras/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,6 @@ def find(op):
return matching


@dataclass
class Successor:
op: OpView | Operation
operands: list[Value]
block: Block
pos: int


_np_dtype_to_mlir_type_ctor = {
np.int8: T.i8,
np.int16: T.i16,
Expand Down Expand Up @@ -285,16 +277,29 @@ def new_dec(*args, **kwargs):
return new_dec


@dataclass
class Successor:
op: OpView | Operation
operands: list[Value]
block: Block
pos: int

def __enter__(self):
self.bb_ctx_manager = bb(self)
return self.bb_ctx_manager.__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
self.bb_ctx_manager.__exit__(exc_type, exc_val, exc_tb)


@contextlib.contextmanager
def bb(*preds: tuple[Successor | OpView]):
current_ip = InsertionPoint.current
op = current_ip.block.owner
op_region = op.regions[0]
args = []
if len(preds):
if isinstance(preds[0], OpView):
args = preds[0].operands
elif isinstance(preds[0], Successor):
if isinstance(preds[0], (OpView, Successor)):
args = preds[0].operands
else:
raise NotImplementedError(f"{preds[0]=} not supported.")
Expand Down
48 changes: 44 additions & 4 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ def foo1():
three = constant(3)
cond = two < three
x = cond_br(cond)
with bb(x.true) as (b2, _):
with x.true as (b2, _):
four = constant(4)
return_([])
with bb(x.false) as (b3, _):
with x.false as (b3, _):
five = constant(5)

foo1.emit()
Expand Down Expand Up @@ -431,10 +431,10 @@ def foo1():
x = cond_br(
cond, true_dest_operands=[two, three], false_dest_operands=[two, three]
)
with bb(x.true) as (b2, _):
with x.true as (b2, _):
four = constant(4)
return_([])
with bb(x.false) as (b3, _):
with x.false as (b3, _):
five = constant(5)

foo1.emit()
Expand Down Expand Up @@ -519,3 +519,43 @@ def mod():
"""
)
filecheck(correct, ctx.module)


def test_successor_ctx_manager(ctx: MLIRContext):
@func
def foo1():
one = constant(1)
return_([])
with bb() as (b1, _):
two = constant(2)
three = constant(3)
cond = two < three
x = cond_br(cond)
with x.true as (b2, _):
four = constant(4)
return_([])
with x.false as (b3, _):
five = constant(5)

foo1()
correct = dedent(
"""\
module {
func.func @foo1() {
%c1_i32 = arith.constant 1 : i32
return
^bb1: // no predecessors
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%0 = arith.cmpi ult, %c2_i32, %c3_i32 : i32
cf.cond_br %0, ^bb2, ^bb3
^bb2: // pred: ^bb1
%c4_i32 = arith.constant 4 : i32
return
^bb3: // pred: ^bb1
%c5_i32_0 = arith.constant 5 : i32
return
}
"""
)
filecheck(correct, ctx.module)
11 changes: 10 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

import mlir.extras.types as T
from mlir.extras.dialects.ext.tensor import S
from mlir.extras.dialects.ext.tensor import S, empty
from mlir.extras.dialects.ext.memref import alloc

# noinspection PyUnresolvedReferences
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down Expand Up @@ -32,3 +33,11 @@ def test_shaped_types(ctx: MLIRContext):

v = vector(3, 3, 3, T.f64())
assert repr(v) == "VectorType(vector<3x3x3xf64>)"


def test_n_elements(ctx: MLIRContext):
ten = empty((1, 2, 3, 4), T.i32())
assert ten.n_elements == 1 * 2 * 3 * 4

mem = alloc((1, 2, 3, 4), T.i32())
assert mem.n_elements == 1 * 2 * 3 * 4