Skip to content

Commit d50fbe4

Browse files
authored
[MLIR][Python] Python binding support for AffineIfOp (#108323)
Fix the AffineIfOp's default builder such that it takes in an IntegerSetAttr. AffineIfOp has skipDefaultBuilders=1 which effectively skips the creation of the default AffineIfOp::builder on the C++ side. (AffineIfOp has two custom OpBuilder defined in the extraClassDeclaration.) However, on the python side, _affine_ops_gen.py shows that the default builder is being created, but it does not accept IntegerSet and thus is useless. This fix at line 411 makes the default python AffineIfOp builder take in an IntegerSet input and does not impact the C++ side of things.
1 parent 7a8fe0f commit d50fbe4

File tree

4 files changed

+139
-1
lines changed

4 files changed

+139
-1
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
407407
}
408408
```
409409
}];
410-
let arguments = (ins Variadic<AnyType>);
410+
let arguments = (ins Variadic<AnyType>,
411+
IntegerSetAttr:$condition);
411412
let results = (outs Variadic<AnyType>:$results);
412413
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
413414

mlir/include/mlir/IR/CommonAttrConstraints.td

+9
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
558558
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
559559
}
560560

561+
// Attributes containing integer sets.
562+
def IntegerSetAttr : Attr<
563+
CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
564+
let storageType = [{::mlir::IntegerSetAttr }];
565+
let returnType = [{ ::mlir::IntegerSet }];
566+
let valueType = NoneType;
567+
let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
568+
}
569+
561570
// Base class for array attributes.
562571
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
563572
let storageType = [{ ::mlir::ArrayAttr }];

mlir/python/mlir/dialects/affine.py

+58
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,61 @@ def for_(
156156
yield iv, iter_args[0]
157157
else:
158158
yield iv
159+
160+
161+
@_ods_cext.register_operation(_Dialect, replace=True)
162+
class AffineIfOp(AffineIfOp):
163+
"""Specialization for the Affine if op class."""
164+
165+
def __init__(
166+
self,
167+
cond: IntegerSet,
168+
results_: Optional[Type] = None,
169+
*,
170+
cond_operands: Optional[_VariadicResultValueT] = None,
171+
has_else: bool = False,
172+
loc=None,
173+
ip=None,
174+
):
175+
"""Creates an Affine `if` operation.
176+
177+
- `cond` is the integer set used to determine which regions of code
178+
will be executed.
179+
- `results` are the list of types to be yielded by the operand.
180+
- `cond_operands` is the list of arguments to substitute the
181+
dimensions, then symbols in the `cond` integer set expression to
182+
determine whether they are in the set.
183+
- `has_else` determines whether the affine if operation has the else
184+
branch.
185+
"""
186+
if results_ is None:
187+
results_ = []
188+
if cond_operands is None:
189+
cond_operands = []
190+
191+
if cond.n_inputs != len(cond_operands):
192+
raise ValueError(
193+
f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
194+
)
195+
196+
operands = []
197+
operands.extend(cond_operands)
198+
results = []
199+
results.extend(results_)
200+
201+
super().__init__(results, cond_operands, cond)
202+
self.regions[0].blocks.append(*[])
203+
if has_else:
204+
self.regions[1].blocks.append(*[])
205+
206+
@property
207+
def then_block(self) -> Block:
208+
"""Returns the then block of the if operation."""
209+
return self.regions[0].blocks[0]
210+
211+
@property
212+
def else_block(self) -> Optional[Block]:
213+
"""Returns the else block of the if operation."""
214+
if len(self.regions[1].blocks) == 0:
215+
return None
216+
return self.regions[1].blocks[0]

mlir/test/python/dialects/affine.py

+70
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
263263
add = arith.addi(i, i)
264264
memref.store(add, it, [i])
265265
affine.yield_([it])
266+
267+
268+
# CHECK-LABEL: TEST: testAffineIfWithoutElse
269+
@constructAndPrintInModule
270+
def testAffineIfWithoutElse():
271+
index = IndexType.get()
272+
i32 = IntegerType.get_signless(32)
273+
d0 = AffineDimExpr.get(0)
274+
275+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
276+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
277+
278+
# CHECK-LABEL: func.func @simple_affine_if(
279+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
280+
# CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
281+
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
282+
# CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
283+
# CHECK: }
284+
# CHECK: return
285+
# CHECK: }
286+
@func.FuncOp.from_py_func(index)
287+
def simple_affine_if(cond_operands):
288+
if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
289+
with InsertionPoint(if_op.then_block):
290+
one = arith.ConstantOp(i32, 1)
291+
add = arith.AddIOp(one, one)
292+
affine.AffineYieldOp([])
293+
return
294+
295+
296+
# CHECK-LABEL: TEST: testAffineIfWithElse
297+
@constructAndPrintInModule
298+
def testAffineIfWithElse():
299+
index = IndexType.get()
300+
i32 = IntegerType.get_signless(32)
301+
d0 = AffineDimExpr.get(0)
302+
303+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
304+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
305+
306+
# CHECK-LABEL: func.func @simple_affine_if_else(
307+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
308+
# CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
309+
# CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
310+
# CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
311+
# CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
312+
# CHECK: } else {
313+
# CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
314+
# CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
315+
# CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
316+
# CHECK: }
317+
# CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
318+
# CHECK: return
319+
# CHECK: }
320+
321+
@func.FuncOp.from_py_func(index)
322+
def simple_affine_if_else(cond_operands):
323+
if_op = affine.AffineIfOp(
324+
cond, [i32, i32], cond_operands=[cond_operands], has_else=True
325+
)
326+
with InsertionPoint(if_op.then_block):
327+
x_true = arith.ConstantOp(i32, 0)
328+
y_true = arith.ConstantOp(i32, 1)
329+
affine.AffineYieldOp([x_true, y_true])
330+
with InsertionPoint(if_op.else_block):
331+
x_false = arith.ConstantOp(i32, 2)
332+
y_false = arith.ConstantOp(i32, 3)
333+
affine.AffineYieldOp([x_false, y_false])
334+
add = arith.AddIOp(if_op.results[0], if_op.results[1])
335+
return

0 commit comments

Comments
 (0)