Skip to content

Commit 74366dc

Browse files
gitoleglanza
authored andcommitted
[CIR][Lowering] Fixed break/continue lowering for loops (llvm#211)
This PR fixes lowering for `break/continue` in loops. The idea is to replace `cir.yield break` and `cir.yield continue` with the branch operations to the corresponding blocks. Note, that we need to ignore nesting loops and don't touch `break` in switch operations. Also, `yield` from `if` need to be considered only when it's not the loop `yield` and `continue` in switch is ignored since it's processed in the loops lowering. Fixes llvm#160
1 parent 23366a9 commit 74366dc

File tree

3 files changed

+702
-8
lines changed

3 files changed

+702
-8
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
240240
return mlir::success();
241241
}
242242

243+
void makeYieldIf(mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
244+
mlir::Block *to,
245+
mlir::ConversionPatternRewriter &rewriter) const {
246+
if (op.getKind() == kind) {
247+
rewriter.setInsertionPoint(op);
248+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, op.getArgs(), to);
249+
}
250+
}
251+
252+
void
253+
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
254+
mlir::Block *continueBlock,
255+
mlir::ConversionPatternRewriter &rewriter) const {
256+
257+
auto processBreak = [&](mlir::Operation *op) {
258+
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
259+
*op)) // don't process breaks in nested loops and switches
260+
return mlir::WalkResult::skip();
261+
262+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
263+
makeYieldIf(mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
264+
265+
return mlir::WalkResult::advance();
266+
};
267+
268+
auto processContinue = [&](mlir::Operation *op) {
269+
if (isa<mlir::cir::LoopOp>(
270+
*op)) // don't process continues in nested loops
271+
return mlir::WalkResult::skip();
272+
273+
if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
274+
makeYieldIf(mlir::cir::YieldOpKind::Continue, yield, continueBlock,
275+
rewriter);
276+
277+
return mlir::WalkResult::advance();
278+
};
279+
280+
loopBody.walk<mlir::WalkOrder::PreOrder>(processBreak);
281+
loopBody.walk<mlir::WalkOrder::PreOrder>(processContinue);
282+
}
283+
243284
mlir::LogicalResult
244285
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
245286
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -267,6 +308,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
267308
auto &stepFrontBlock = stepRegion.front();
268309
auto stepYield =
269310
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
311+
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
312+
313+
lowerNestedBreakContinue(bodyRegion, continueBlock, &stepBlock, rewriter);
270314

271315
// Move loop op region contents to current CFG.
272316
rewriter.inlineRegionBefore(condRegion, continueBlock);
@@ -289,8 +333,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
289333

290334
// Branch from body to condition or to step on for-loop cases.
291335
rewriter.setInsertionPoint(bodyYield);
292-
auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
293-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &bodyExit);
336+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
294337

295338
// Is a for loop: branch from step to condition.
296339
if (kind == LoopKind::For) {
@@ -488,6 +531,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
488531
}
489532
};
490533

534+
static bool isLoopYield(mlir::cir::YieldOp &op) {
535+
return op.getKind() == mlir::cir::YieldOpKind::Break ||
536+
op.getKind() == mlir::cir::YieldOpKind::Continue;
537+
}
538+
491539
class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
492540
public:
493541
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -516,8 +564,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
516564
rewriter.setInsertionPointToEnd(thenAfterBody);
517565
if (auto thenYieldOp =
518566
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
519-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
520-
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
567+
if (!isLoopYield(thenYieldOp)) // lowering of parent loop yields is
568+
// deferred to loop lowering
569+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
570+
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
521571
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
522572
llvm_unreachable("what are we terminating with?");
523573
}
@@ -545,8 +595,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
545595
rewriter.setInsertionPointToEnd(elseAfterBody);
546596
if (auto elseYieldOp =
547597
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
548-
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
549-
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
598+
if (!isLoopYield(elseYieldOp)) // lowering of parent loop yields is
599+
// deferred to loop lowering
600+
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
601+
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
550602
} else if (!dyn_cast<mlir::cir::ReturnOp>(
551603
elseAfterBody->getTerminator())) {
552604
llvm_unreachable("what are we terminating with?");
@@ -1109,6 +1161,9 @@ class CIRSwitchOpLowering
11091161
case mlir::cir::YieldOpKind::Break:
11101162
rewriteYieldOp(rewriter, yieldOp, exitBlock);
11111163
break;
1164+
case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1165+
// loop lowering
1166+
break;
11121167
default:
11131168
return op->emitError("invalid yield kind in case statement");
11141169
}
@@ -1692,8 +1747,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
16921747
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
16931748
CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
16941749
CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
1695-
CIRPtrDiffOpLowering>(
1696-
converter, patterns.getContext());
1750+
CIRPtrDiffOpLowering>(converter, patterns.getContext());
16971751
}
16981752

16991753
namespace {

0 commit comments

Comments
 (0)