@@ -240,6 +240,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
240
240
return mlir::success ();
241
241
}
242
242
243
+ void makeYieldIf (mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
244
+ mlir::Block *to,
245
+ mlir::ConversionPatternRewriter &rewriter) const {
246
+ if (op.getKind () == kind) {
247
+ rewriter.setInsertionPoint (op);
248
+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(op, op.getArgs (), to);
249
+ }
250
+ }
251
+
252
+ void
253
+ lowerNestedBreakContinue (mlir::Region &loopBody, mlir::Block *exitBlock,
254
+ mlir::Block *continueBlock,
255
+ mlir::ConversionPatternRewriter &rewriter) const {
256
+
257
+ auto processBreak = [&](mlir::Operation *op) {
258
+ if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
259
+ *op)) // don't process breaks in nested loops and switches
260
+ return mlir::WalkResult::skip ();
261
+
262
+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
263
+ makeYieldIf (mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);
264
+
265
+ return mlir::WalkResult::advance ();
266
+ };
267
+
268
+ auto processContinue = [&](mlir::Operation *op) {
269
+ if (isa<mlir::cir::LoopOp>(
270
+ *op)) // don't process continues in nested loops
271
+ return mlir::WalkResult::skip ();
272
+
273
+ if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
274
+ makeYieldIf (mlir::cir::YieldOpKind::Continue, yield, continueBlock,
275
+ rewriter);
276
+
277
+ return mlir::WalkResult::advance ();
278
+ };
279
+
280
+ loopBody.walk <mlir::WalkOrder::PreOrder>(processBreak);
281
+ loopBody.walk <mlir::WalkOrder::PreOrder>(processContinue);
282
+ }
283
+
243
284
mlir::LogicalResult
244
285
matchAndRewrite (mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
245
286
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -267,6 +308,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
267
308
auto &stepFrontBlock = stepRegion.front ();
268
309
auto stepYield =
269
310
dyn_cast<mlir::cir::YieldOp>(stepRegion.back ().getTerminator ());
311
+ auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
312
+
313
+ lowerNestedBreakContinue (bodyRegion, continueBlock, &stepBlock, rewriter);
270
314
271
315
// Move loop op region contents to current CFG.
272
316
rewriter.inlineRegionBefore (condRegion, continueBlock);
@@ -289,8 +333,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
289
333
290
334
// Branch from body to condition or to step on for-loop cases.
291
335
rewriter.setInsertionPoint (bodyYield);
292
- auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
293
- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &bodyExit);
336
+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(bodyYield, &stepBlock);
294
337
295
338
// Is a for loop: branch from step to condition.
296
339
if (kind == LoopKind::For) {
@@ -488,6 +531,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
488
531
}
489
532
};
490
533
534
+ static bool isLoopYield (mlir::cir::YieldOp &op) {
535
+ return op.getKind () == mlir::cir::YieldOpKind::Break ||
536
+ op.getKind () == mlir::cir::YieldOpKind::Continue;
537
+ }
538
+
491
539
class CIRIfLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
492
540
public:
493
541
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -516,8 +564,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
516
564
rewriter.setInsertionPointToEnd (thenAfterBody);
517
565
if (auto thenYieldOp =
518
566
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator ())) {
519
- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
520
- thenYieldOp, thenYieldOp.getArgs (), continueBlock);
567
+ if (!isLoopYield (thenYieldOp)) // lowering of parent loop yields is
568
+ // deferred to loop lowering
569
+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
570
+ thenYieldOp, thenYieldOp.getArgs (), continueBlock);
521
571
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator ())) {
522
572
llvm_unreachable (" what are we terminating with?" );
523
573
}
@@ -545,8 +595,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
545
595
rewriter.setInsertionPointToEnd (elseAfterBody);
546
596
if (auto elseYieldOp =
547
597
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator ())) {
548
- rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
549
- elseYieldOp, elseYieldOp.getArgs (), continueBlock);
598
+ if (!isLoopYield (elseYieldOp)) // lowering of parent loop yields is
599
+ // deferred to loop lowering
600
+ rewriter.replaceOpWithNewOp <mlir::cir::BrOp>(
601
+ elseYieldOp, elseYieldOp.getArgs (), continueBlock);
550
602
} else if (!dyn_cast<mlir::cir::ReturnOp>(
551
603
elseAfterBody->getTerminator ())) {
552
604
llvm_unreachable (" what are we terminating with?" );
@@ -1109,6 +1161,9 @@ class CIRSwitchOpLowering
1109
1161
case mlir::cir::YieldOpKind::Break:
1110
1162
rewriteYieldOp (rewriter, yieldOp, exitBlock);
1111
1163
break ;
1164
+ case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
1165
+ // loop lowering
1166
+ break ;
1112
1167
default :
1113
1168
return op->emitError (" invalid yield kind in case statement" );
1114
1169
}
@@ -1692,8 +1747,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
1692
1747
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
1693
1748
CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
1694
1749
CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
1695
- CIRPtrDiffOpLowering>(
1696
- converter, patterns.getContext ());
1750
+ CIRPtrDiffOpLowering>(converter, patterns.getContext ());
1697
1751
}
1698
1752
1699
1753
namespace {
0 commit comments