Skip to content

Commit d905bf3

Browse files
committed
[MLIR][Python] Python binding support for AffineIfOp
1 parent 8168088 commit d905bf3

File tree

4 files changed

+139
-1
lines changed

4 files changed

+139
-1
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
407407
}
408408
```
409409
}];
410-
let arguments = (ins Variadic<AnyType>);
410+
let arguments = (ins Variadic<AnyType>,
411+
IntegerSetAttr:$condition);
411412
let results = (outs Variadic<AnyType>:$results);
412413
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
413414

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
557557
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
558558
}
559559

560+
// Attributes containing integer sets.
561+
def IntegerSetAttr : Attr<
562+
CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
563+
let storageType = [{::mlir::IntegerSetAttr }];
564+
let returnType = [{ ::mlir::IntegerSet }];
565+
let valueType = NoneType;
566+
let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
567+
}
568+
560569
// Base class for array attributes.
561570
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
562571
let storageType = [{ ::mlir::ArrayAttr }];

mlir/python/mlir/dialects/affine.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,61 @@ def for_(
156156
yield iv, iter_args[0]
157157
else:
158158
yield iv
159+
160+
161+
@_ods_cext.register_operation(_Dialect, replace=True)
162+
class AffineIfOp(AffineIfOp):
163+
"""Specialization for the Affine if op class."""
164+
165+
def __init__(
166+
self,
167+
cond: IntegerSet,
168+
results_: Optional[Type] = None,
169+
*,
170+
cond_operands: Optional[_VariadicResultValueT] = None,
171+
hasElse: bool = False,
172+
loc=None,
173+
ip=None,
174+
):
175+
"""Creates an Affine `if` operation.
176+
177+
- `cond` is the integer set used to determine which regions of code
178+
will be executed.
179+
- `results` are the list of types to be yielded by the operand.
180+
- `cond_operands` is the list of arguments to substitute the
181+
dimensions, then symbols in the `cond` integer set expression to
182+
determine whether they are in the set.
183+
- `hasElse` determines whether the affine if operation has the else
184+
branch.
185+
"""
186+
if results_ is None:
187+
results_ = []
188+
if cond_operands is None:
189+
cond_operands = []
190+
191+
if not (actual_n_inputs := len(cond_operands)) == (
192+
exp_n_inputs := cond.n_inputs
193+
):
194+
raise ValueError(
195+
f"expected {exp_n_inputs} condition operands, got {actual_n_inputs}"
196+
)
197+
198+
operands = []
199+
operands.extend(cond_operands)
200+
results = []
201+
results.extend(results_)
202+
203+
super().__init__(results, cond_operands, cond)
204+
self.regions[0].blocks.append(*[])
205+
if hasElse:
206+
self.regions[1].blocks.append(*[])
207+
208+
@property
209+
def then_block(self):
210+
"""Returns the then block of the if operation."""
211+
return self.regions[0].blocks[0]
212+
213+
@property
214+
def else_block(self):
215+
"""Returns the else block of the if operation."""
216+
return self.regions[1].blocks[0]

mlir/test/python/dialects/affine.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,73 @@ def range_loop_8(lb, ub, memref_v):
265265
add = arith.addi(i, i)
266266
memref.store(add, it, [i])
267267
affine.yield_([it])
268+
269+
270+
# CHECK-LABEL: TEST: testAffineIfWithoutElse
271+
@constructAndPrintInModule
272+
def testAffineIfWithoutElse():
273+
index = IndexType.get()
274+
i32 = IntegerType.get_signless(32)
275+
d0 = AffineDimExpr.get(0)
276+
277+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
278+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
279+
280+
# CHECK-LABEL: func.func @simple_affine_if(
281+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
282+
# CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
283+
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
284+
# CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
285+
# CHECK: }
286+
# CHECK: return
287+
# CHECK: }
288+
@func.FuncOp.from_py_func(index)
289+
def simple_affine_if(cond_operands):
290+
if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
291+
with InsertionPoint(if_op.then_block):
292+
one = arith.ConstantOp(i32, 1)
293+
add = arith.AddIOp(one, one)
294+
affine.AffineYieldOp([])
295+
return
296+
297+
298+
# CHECK-LABEL: TEST: testAffineIfWithElse
299+
@constructAndPrintInModule
300+
def testAffineIfWithElse():
301+
index = IndexType.get()
302+
i32 = IntegerType.get_signless(32)
303+
d0 = AffineDimExpr.get(0)
304+
305+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
306+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
307+
308+
# CHECK-LABEL: func.func @simple_affine_if_else(
309+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
310+
# CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
311+
# CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
312+
# CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
313+
# CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
314+
# CHECK: } else {
315+
# CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
316+
# CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
317+
# CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
318+
# CHECK: }
319+
# CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
320+
# CHECK: return
321+
# CHECK: }
322+
323+
@func.FuncOp.from_py_func(index)
324+
def simple_affine_if_else(cond_operands):
325+
if_op = affine.AffineIfOp(
326+
cond, [i32, i32], cond_operands=[cond_operands], hasElse=True
327+
)
328+
with InsertionPoint(if_op.then_block):
329+
x_true = arith.ConstantOp(i32, 0)
330+
y_true = arith.ConstantOp(i32, 1)
331+
affine.AffineYieldOp([x_true, y_true])
332+
with InsertionPoint(if_op.else_block):
333+
x_false = arith.ConstantOp(i32, 2)
334+
y_false = arith.ConstantOp(i32, 3)
335+
affine.AffineYieldOp([x_false, y_false])
336+
add = arith.AddIOp(if_op.results[0], if_op.results[1])
337+
return

0 commit comments

Comments
 (0)