Skip to content

Commit da784a2

Browse files
[mlir][IR] Add RewriterBase::moveBlockBefore and fix bug in moveOpBefore (#79579)
This commit adds a new method to the rewriter API: `moveBlockBefore`. This op is utilized by `inlineRegionBefore` and covered by dialect conversion test cases. Also fixes a bug in `moveOpBefore`, where the previous op location was not passed correctly. Adds a test case to `test-strict-pattern-driver.mlir`.
1 parent b210cbb commit da784a2

File tree

6 files changed

+92
-17
lines changed

6 files changed

+92
-17
lines changed

mlir/include/mlir/IR/Block.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
6767
/// specific block.
6868
void moveBefore(Block *block);
6969

70+
/// Unlink this block from its current region and insert it right before the
71+
/// block that the given iterator points to in the region region.
72+
void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
73+
7074
/// Unlink this Block from its parent region and delete it.
7175
void erase();
7276

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
614614
virtual void moveOpAfter(Operation *op, Block *block,
615615
Block::iterator iterator);
616616

617+
/// Unlink this block and insert it right before `existingBlock`.
618+
void moveBlockBefore(Block *block, Block *anotherBlock);
619+
620+
/// Unlink this block and insert it right before the location that the given
621+
/// iterator points to in the given region.
622+
void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
623+
617624
/// This method is used to notify the rewriter that an in-place operation
618625
/// modification is about to happen. A call to this function *must* be
619626
/// followed by a call to either `finalizeOpModification` or

mlir/lib/IR/Block.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
5252
/// specific block.
5353
void Block::moveBefore(Block *block) {
5454
assert(block->getParent() && "cannot insert before a block without a parent");
55-
block->getParent()->getBlocks().splice(
56-
block->getIterator(), getParent()->getBlocks(), getIterator());
55+
moveBefore(block->getParent(), block->getIterator());
56+
}
57+
58+
/// Unlink this block from its current region and insert it right before the
59+
/// block that the given iterator points to in the region region.
60+
void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
61+
region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
5762
}
5863

5964
/// Unlink this Block from its parent Region and delete it.

mlir/lib/IR/PatternMatch.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
350350
}
351351

352352
// Move blocks from the beginning of the region one-by-one.
353-
while (!region.empty()) {
354-
Block *block = &region.front();
355-
parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
356-
listener->notifyBlockInserted(block, &region, region.begin());
357-
}
353+
while (!region.empty())
354+
moveBlockBefore(&region.front(), &parent, before);
358355
}
359356
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
360357
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,18 +375,33 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
378375
cloneRegionBefore(region, *before->getParent(), before->getIterator());
379376
}
380377

378+
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
379+
moveBlockBefore(block, anotherBlock->getParent(),
380+
anotherBlock->getIterator());
381+
}
382+
383+
void RewriterBase::moveBlockBefore(Block *block, Region *region,
384+
Region::iterator iterator) {
385+
Region *currentRegion = block->getParent();
386+
Region::iterator nextIterator = std::next(block->getIterator());
387+
block->moveBefore(region, iterator);
388+
if (listener)
389+
listener->notifyBlockInserted(block, /*previous=*/currentRegion,
390+
/*previousIt=*/nextIterator);
391+
}
392+
381393
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
382394
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
383395
}
384396

385397
void RewriterBase::moveOpBefore(Operation *op, Block *block,
386398
Block::iterator iterator) {
387399
Block *currentBlock = op->getBlock();
388-
Block::iterator currentIterator = op->getIterator();
400+
Block::iterator nextIterator = std::next(op->getIterator());
389401
op->moveBefore(block, iterator);
390402
if (listener)
391403
listener->notifyOperationInserted(
392-
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
404+
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
393405
}
394406

395407
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +410,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
398410

399411
void RewriterBase::moveOpAfter(Operation *op, Block *block,
400412
Block::iterator iterator) {
401-
Block *currentBlock = op->getBlock();
402-
Block::iterator currentIterator = op->getIterator();
403-
op->moveAfter(block, iterator);
404-
if (listener)
405-
listener->notifyOperationInserted(
406-
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
413+
assert(iterator != block->end() && "cannot move after end of block");
414+
moveOpBefore(op, block, std::next(iterator));
407415
}

mlir/test/Transforms/test-strict-pattern-driver.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func.func @test_erase() {
2424

2525
// -----
2626

27+
// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
2728
// CHECK-EN-LABEL: func @test_insert_same_op
2829
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
2930
// CHECK-EN: "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
3536

3637
// -----
3738

39+
// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
3840
// CHECK-EN-LABEL: func @test_replace_with_new_op
3941
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
4042
// CHECK-EN: %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
4951

5052
// -----
5153

54+
// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
55+
// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
56+
// CHECK-EN: notifyOperationRemoved: test.erase_op
5257
// CHECK-EN-LABEL: func @test_replace_with_erase_op
5358
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
5459
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -229,3 +234,18 @@ func.func @test_remove_diamond(%c: i1) {
229234
}) : () -> ()
230235
return
231236
}
237+
238+
// -----
239+
240+
// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
241+
// CHECK-AN-LABEL: func @test_move_op_before(
242+
// CHECK-AN: test.move_before_parent_op
243+
// CHECK-AN: test.op_with_region
244+
// CHECK-AN: test.dummy_terminator
245+
func.func @test_move_op_before() {
246+
"test.op_with_region"() ({
247+
"test.move_before_parent_op"() : () -> ()
248+
"test.dummy_terminator"() : () ->()
249+
}) : () -> ()
250+
return
251+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,21 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
198198
}
199199
};
200200

201+
/// This pattern moves "test.move_before_parent_op" before the parent op.
202+
struct MoveBeforeParentOp : public RewritePattern {
203+
MoveBeforeParentOp(MLIRContext *context)
204+
: RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
205+
206+
LogicalResult matchAndRewrite(Operation *op,
207+
PatternRewriter &rewriter) const override {
208+
// Do not hoist past functions.
209+
if (isa<FunctionOpInterface>(op->getParentOp()))
210+
return failure();
211+
rewriter.moveOpBefore(op, op->getParentOp());
212+
return success();
213+
}
214+
};
215+
201216
struct TestPatternDriver
202217
: public PassWrapper<TestPatternDriver, OperationPass<>> {
203218
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +253,20 @@ struct TestPatternDriver
238253
};
239254

240255
struct DumpNotifications : public RewriterBase::Listener {
256+
void notifyOperationInserted(Operation *op,
257+
OpBuilder::InsertPoint previous) override {
258+
llvm::outs() << "notifyOperationInserted: " << op->getName();
259+
if (!previous.isSet()) {
260+
llvm::outs() << ", was unlinked\n";
261+
} else {
262+
if (previous.getPoint() == previous.getBlock()->end()) {
263+
llvm::outs() << ", was last in block\n";
264+
} else {
265+
llvm::outs() << ", previous = " << previous.getPoint()->getName()
266+
<< "\n";
267+
}
268+
}
269+
}
241270
void notifyOperationRemoved(Operation *op) override {
242271
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
243272
}
@@ -267,14 +296,16 @@ struct TestStrictPatternDriver
267296
ReplaceWithNewOp,
268297
EraseOp,
269298
ChangeBlockOp,
270-
ImplicitChangeOp
299+
ImplicitChangeOp,
300+
MoveBeforeParentOp
271301
// clang-format on
272302
>(ctx);
273303
SmallVector<Operation *> ops;
274304
getOperation()->walk([&](Operation *op) {
275305
StringRef opName = op->getName().getStringRef();
276306
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
277-
opName == "test.replace_with_new_op" || opName == "test.erase_op") {
307+
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
308+
opName == "test.move_before_parent_op") {
278309
ops.push_back(op);
279310
}
280311
});

0 commit comments

Comments
 (0)