diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index c9d9202ae3cf1..ac0cf36396fa8 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if", } ``` }]; - let arguments = (ins Variadic); + let arguments = (ins Variadic, + IntegerSetAttr:$condition); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion); diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index de5f6797235e3..17ca82c510f8a 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> { let constBuilderCall = "::mlir::AffineMapAttr::get($0)"; } +// Attributes containing integer sets. +def IntegerSetAttr : Attr< +CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> { + let storageType = [{::mlir::IntegerSetAttr }]; + let returnType = [{ ::mlir::IntegerSet }]; + let valueType = NoneType; + let constBuilderCall = "::mlir::IntegerSetAttr::get($0)"; +} + // Base class for array attributes. class ArrayAttrBase : Attr { let storageType = [{ ::mlir::ArrayAttr }]; diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 913cea61105ce..7641d36e39799 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -156,3 +156,61 @@ def for_( yield iv, iter_args[0] else: yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineIfOp(AffineIfOp): + """Specialization for the Affine if op class.""" + + def __init__( + self, + cond: IntegerSet, + results_: Optional[Type] = None, + *, + cond_operands: Optional[_VariadicResultValueT] = None, + has_else: bool = False, + loc=None, + ip=None, + ): + """Creates an Affine `if` operation. + + - `cond` is the integer set used to determine which regions of code + will be executed. + - `results` are the list of types to be yielded by the operand. + - `cond_operands` is the list of arguments to substitute the + dimensions, then symbols in the `cond` integer set expression to + determine whether they are in the set. + - `has_else` determines whether the affine if operation has the else + branch. + """ + if results_ is None: + results_ = [] + if cond_operands is None: + cond_operands = [] + + if cond.n_inputs != len(cond_operands): + raise ValueError( + f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}" + ) + + operands = [] + operands.extend(cond_operands) + results = [] + results.extend(results_) + + super().__init__(results, cond_operands, cond) + self.regions[0].blocks.append(*[]) + if has_else: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self) -> Block: + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self) -> Optional[Block]: + """Returns the else block of the if operation.""" + if len(self.regions[1].blocks) == 0: + return None + return self.regions[1].blocks[0] diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py index 0dc69d7ba522d..7faae6ccedc97 100644 --- a/mlir/test/python/dialects/affine.py +++ b/mlir/test/python/dialects/affine.py @@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v): add = arith.addi(i, i) memref.store(add, it, [i]) affine.yield_([it]) + + +# CHECK-LABEL: TEST: testAffineIfWithoutElse +@constructAndPrintInModule +def testAffineIfWithoutElse(): + index = IndexType.get() + i32 = IntegerType.get_signless(32) + d0 = AffineDimExpr.get(0) + + # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)> + cond = IntegerSet.get(1, 0, [d0 - 5], [False]) + + # CHECK-LABEL: func.func @simple_affine_if( + # CHECK-SAME: %[[VAL_0:.*]]: index) { + # CHECK: affine.if #[[$SET0]](%[[VAL_0]]) { + # CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32 + # CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32 + # CHECK: } + # CHECK: return + # CHECK: } + @func.FuncOp.from_py_func(index) + def simple_affine_if(cond_operands): + if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands]) + with InsertionPoint(if_op.then_block): + one = arith.ConstantOp(i32, 1) + add = arith.AddIOp(one, one) + affine.AffineYieldOp([]) + return + + +# CHECK-LABEL: TEST: testAffineIfWithElse +@constructAndPrintInModule +def testAffineIfWithElse(): + index = IndexType.get() + i32 = IntegerType.get_signless(32) + d0 = AffineDimExpr.get(0) + + # CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)> + cond = IntegerSet.get(1, 0, [d0 - 5], [False]) + + # CHECK-LABEL: func.func @simple_affine_if_else( + # CHECK-SAME: %[[VAL_0:.*]]: index) { + # CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) { + # CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32 + # CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32 + # CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32 + # CHECK: } else { + # CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32 + # CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32 + # CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32 + # CHECK: } + # CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32 + # CHECK: return + # CHECK: } + + @func.FuncOp.from_py_func(index) + def simple_affine_if_else(cond_operands): + if_op = affine.AffineIfOp( + cond, [i32, i32], cond_operands=[cond_operands], has_else=True + ) + with InsertionPoint(if_op.then_block): + x_true = arith.ConstantOp(i32, 0) + y_true = arith.ConstantOp(i32, 1) + affine.AffineYieldOp([x_true, y_true]) + with InsertionPoint(if_op.else_block): + x_false = arith.ConstantOp(i32, 2) + y_false = arith.ConstantOp(i32, 3) + affine.AffineYieldOp([x_false, y_false]) + add = arith.AddIOp(if_op.results[0], if_op.results[1]) + return