Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mlir/extras/dialects/ext/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def alloca(
):
if loc is None:
loc = get_user_code_loc()
return get_op_result_or_op_results(
_alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip)
)
return _alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip)


def load(mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None):
Expand Down
34 changes: 33 additions & 1 deletion tests/test_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import mlir.extras.types as T
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.dialects.ext.arith import Scalar, constant
from mlir.extras.dialects.ext.memref import alloc, S
from mlir.extras.dialects.ext.memref import (
alloc,
alloca,
S,
alloca_scope,
alloca_scope_return,
)
from mlir.extras.dialects.ext.scf import (
range_,
yield_,
Expand Down Expand Up @@ -43,6 +49,32 @@ def test_simple_literal_indexing(ctx: MLIRContext):
filecheck(correct, ctx.module)


def test_simple_literal_indexing_alloca(ctx: MLIRContext):
@alloca_scope([])
def demo_scope2():
mem = alloca((10, 22, 333, 4444), T.i32())

w = mem[2, 4, 6, 8]
assert isinstance(w, Scalar)
alloca_scope_return([])

correct = dedent(
"""\
module {
memref.alloca_scope {
%alloca = memref.alloca() : memref<10x22x333x4444xi32>
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c6 = arith.constant 6 : index
%c8 = arith.constant 8 : index
%0 = memref.load %alloca[%c2, %c4, %c6, %c8] : memref<10x22x333x4444xi32>
}
}
"""
)
filecheck(correct, ctx.module)


def test_ellipsis_and_full_slice(ctx: MLIRContext):
mem = alloc((10, 22, 333, 4444), T.i32())

Expand Down