diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 53e45c1a6cb3..831d486a5d0e 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -488,12 +488,43 @@ class CIRBrOpLowering : public mlir::OpRewritePattern { } }; +class CIRScopeOpLowering : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::ScopeOp scopeOp, + mlir::PatternRewriter &rewriter) const override { + // Empty scope: just remove it. + if (scopeOp.getRegion().empty()) { + rewriter.eraseOp(scopeOp); + return mlir::success(); + } + + for (auto &block : scopeOp.getRegion()) { + rewriter.setInsertionPointToEnd(&block); + auto *terminator = block.getTerminator(); + rewriter.replaceOpWithNewOp( + terminator, terminator->getOperands()); + } + + rewriter.setInsertionPoint(scopeOp); + auto newScopeOp = rewriter.create( + scopeOp.getLoc(), scopeOp.getResultTypes()); + rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), + newScopeOp.getBodyRegion(), + newScopeOp.getBodyRegion().end()); + rewriter.replaceOp(scopeOp, newScopeOp); + + return mlir::LogicalResult::success(); + } +}; + void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); + CIRReturnLowering, CIRScopeOpLowering>(patterns.getContext()); patterns.add(converter, patterns.getContext()); } diff --git a/clang/test/CIR/Lowering/ThroughMLIR/scope.cir b/clang/test/CIR/Lowering/ThroughMLIR/scope.cir new file mode 100644 index 000000000000..e5e885286b02 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/scope.cir @@ -0,0 +1,52 @@ +// RUN: cir-opt %s -cir-to-mlir -o - | FileCheck %s -check-prefix=MLIR +// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM + +module { + cir.func @foo() { + cir.scope { + %0 = cir.alloca i32, cir.ptr , ["a", init] {alignment = 4 : i64} + %1 = cir.const(4 : i32) : i32 + cir.store %1, %0 : i32, cir.ptr + } + cir.return + } + +// MLIR: func.func @foo() +// MLIR-NEXT: memref.alloca_scope +// MLIR-NEXT: %alloca = memref.alloca() {alignment = 4 : i64} : memref +// MLIR-NEXT: %c4_i32 = arith.constant 4 : i32 +// MLIR-NEXT: memref.store %c4_i32, %alloca[] : memref +// MLIR-NEXT: } +// MLIR-NEXT: return + + +// LLVM: define void @foo() +// LLVM-NEXT: %1 = call ptr @llvm.stacksave() +// LLVM-NEXT: br label %2 +// LLVM-EMPTY: +// LLVM-NEXT: 2: +// LLVM-NEXT: %3 = alloca i32, i64 1, align 4 +// LLVM-NEXT: %4 = insertvalue { ptr, ptr, i64 } undef, ptr %3, 0 +// LLVM-NEXT: %5 = insertvalue { ptr, ptr, i64 } %4, ptr %3, 1 +// LLVM-NEXT: %6 = insertvalue { ptr, ptr, i64 } %5, i64 0, 2 +// LLVM-NEXT: %7 = extractvalue { ptr, ptr, i64 } %6, 1 +// LLVM-NEXT: store i32 4, ptr %7, align 4 +// LLVM-NEXT: call void @llvm.stackrestore(ptr %1) +// LLVM-NEXT: br label %8 +// LLVM-EMPTY: +// LLVM-NEXT: 8: +// LLVM-NEXT: ret void +// LLVM-NEXT: } + + + // Should drop empty scopes. + cir.func @empty_scope() { + cir.scope { + } + cir.return + } + // MLIR: func.func @empty_scope() + // MLIR-NEXT: return + // MLIR-NEXT: } + +}