From 6e9d25dd6cb3b8d7d0c2ec939fe185889f66b1aa Mon Sep 17 00:00:00 2001 From: max Date: Sat, 20 Jan 2024 10:02:39 -0600 Subject: [PATCH] fix alloca (don't get_op_results twice) --- mlir/extras/dialects/ext/memref.py | 4 +--- tests/test_memref.py | 34 +++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/mlir/extras/dialects/ext/memref.py b/mlir/extras/dialects/ext/memref.py index becd248e..fa04f73e 100644 --- a/mlir/extras/dialects/ext/memref.py +++ b/mlir/extras/dialects/ext/memref.py @@ -45,9 +45,7 @@ def alloca( ): if loc is None: loc = get_user_code_loc() - return get_op_result_or_op_results( - _alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip) - ) + return _alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip) def load(mem: Value, indices: Sequence[Union[Value, int]], *, loc=None, ip=None): diff --git a/tests/test_memref.py b/tests/test_memref.py index d5415d80..27a59998 100644 --- a/tests/test_memref.py +++ b/tests/test_memref.py @@ -7,7 +7,13 @@ import mlir.extras.types as T from mlir.extras.ast.canonicalize import canonicalize from mlir.extras.dialects.ext.arith import Scalar, constant -from mlir.extras.dialects.ext.memref import alloc, S +from mlir.extras.dialects.ext.memref import ( + alloc, + alloca, + S, + alloca_scope, + alloca_scope_return, +) from mlir.extras.dialects.ext.scf import ( range_, yield_, @@ -43,6 +49,32 @@ def test_simple_literal_indexing(ctx: MLIRContext): filecheck(correct, ctx.module) +def test_simple_literal_indexing_alloca(ctx: MLIRContext): + @alloca_scope([]) + def demo_scope2(): + mem = alloca((10, 22, 333, 4444), T.i32()) + + w = mem[2, 4, 6, 8] + assert isinstance(w, Scalar) + alloca_scope_return([]) + + correct = dedent( + """\ + module { + memref.alloca_scope { + %alloca = memref.alloca() : memref<10x22x333x4444xi32> + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %0 = memref.load %alloca[%c2, %c4, %c6, %c8] : memref<10x22x333x4444xi32> + } + } + """ + ) + filecheck(correct, ctx.module) + + def test_ellipsis_and_full_slice(ctx: MLIRContext): mem = alloc((10, 22, 333, 4444), T.i32())