Skip to content

[MLIR][Python] Python binding support for AffineIfOp #108323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
}
```
}];
let arguments = (ins Variadic<AnyType>);
let arguments = (ins Variadic<AnyType>,
IntegerSetAttr:$condition);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);

Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/CommonAttrConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
}

// Attributes containing integer sets.
def IntegerSetAttr : Attr<
CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
let storageType = [{::mlir::IntegerSetAttr }];
let returnType = [{ ::mlir::IntegerSet }];
let valueType = NoneType;
let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
}

// Base class for array attributes.
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
let storageType = [{ ::mlir::ArrayAttr }];
Expand Down
58 changes: 58 additions & 0 deletions mlir/python/mlir/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,61 @@ def for_(
yield iv, iter_args[0]
else:
yield iv


@_ods_cext.register_operation(_Dialect, replace=True)
class AffineIfOp(AffineIfOp):
"""Specialization for the Affine if op class."""

def __init__(
self,
cond: IntegerSet,
results_: Optional[Type] = None,
*,
cond_operands: Optional[_VariadicResultValueT] = None,
has_else: bool = False,
loc=None,
ip=None,
):
"""Creates an Affine `if` operation.

- `cond` is the integer set used to determine which regions of code
will be executed.
- `results` are the list of types to be yielded by the operand.
- `cond_operands` is the list of arguments to substitute the
dimensions, then symbols in the `cond` integer set expression to
determine whether they are in the set.
- `has_else` determines whether the affine if operation has the else
branch.
"""
if results_ is None:
results_ = []
if cond_operands is None:
cond_operands = []

if cond.n_inputs != len(cond_operands):
raise ValueError(
f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
)

operands = []
operands.extend(cond_operands)
results = []
results.extend(results_)

super().__init__(results, cond_operands, cond)
self.regions[0].blocks.append(*[])
if has_else:
self.regions[1].blocks.append(*[])

@property
def then_block(self) -> Block:
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]

@property
def else_block(self) -> Optional[Block]:
"""Returns the else block of the if operation."""
if len(self.regions[1].blocks) == 0:
return None
return self.regions[1].blocks[0]
70 changes: 70 additions & 0 deletions mlir/test/python/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
add = arith.addi(i, i)
memref.store(add, it, [i])
affine.yield_([it])


# CHECK-LABEL: TEST: testAffineIfWithoutElse
@constructAndPrintInModule
def testAffineIfWithoutElse():
index = IndexType.get()
i32 = IntegerType.get_signless(32)
d0 = AffineDimExpr.get(0)

# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
cond = IntegerSet.get(1, 0, [d0 - 5], [False])

# CHECK-LABEL: func.func @simple_affine_if(
# CHECK-SAME: %[[VAL_0:.*]]: index) {
# CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
# CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
# CHECK: }
# CHECK: return
# CHECK: }
@func.FuncOp.from_py_func(index)
def simple_affine_if(cond_operands):
if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
with InsertionPoint(if_op.then_block):
one = arith.ConstantOp(i32, 1)
add = arith.AddIOp(one, one)
affine.AffineYieldOp([])
return


# CHECK-LABEL: TEST: testAffineIfWithElse
@constructAndPrintInModule
def testAffineIfWithElse():
index = IndexType.get()
i32 = IntegerType.get_signless(32)
d0 = AffineDimExpr.get(0)

# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
cond = IntegerSet.get(1, 0, [d0 - 5], [False])

# CHECK-LABEL: func.func @simple_affine_if_else(
# CHECK-SAME: %[[VAL_0:.*]]: index) {
# CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
# CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
# CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
# CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
# CHECK: } else {
# CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
# CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
# CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
# CHECK: }
# CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
# CHECK: return
# CHECK: }

@func.FuncOp.from_py_func(index)
def simple_affine_if_else(cond_operands):
if_op = affine.AffineIfOp(
cond, [i32, i32], cond_operands=[cond_operands], has_else=True
)
with InsertionPoint(if_op.then_block):
x_true = arith.ConstantOp(i32, 0)
y_true = arith.ConstantOp(i32, 1)
affine.AffineYieldOp([x_true, y_true])
with InsertionPoint(if_op.else_block):
x_false = arith.ConstantOp(i32, 2)
y_false = arith.ConstantOp(i32, 3)
affine.AffineYieldOp([x_false, y_false])
add = arith.AddIOp(if_op.results[0], if_op.results[1])
return
Loading