diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 852d3aa131148..7dc6b95c37b87 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2446,4 +2446,41 @@ def AssumeOp : CIR_Op<"assume"> { }]; } +//===----------------------------------------------------------------------===// +// Branch Probability Operations +//===----------------------------------------------------------------------===// + +def ExpectOp : CIR_Op<"expect", + [Pure, AllTypesMatch<["result", "val", "expected"]>]> { + let summary = "Tell the optimizer that two values are likely to be equal."; + let description = [{ + The `cir.expect` operation may take 2 or 3 arguments. + + When the argument `prob` is missing, this operation effectively models the + `__builtin_expect` builtin function. It tells the optimizer that `val` and + `expected` are likely to be equal. + + When the argumen `prob` is present, this operation effectively models the + `__builtin_expect_with_probability` builtin function. It tells the + optimizer that `val` and `expected` are equal to each other with a certain + probability. + + `val` and `expected` must be integers and their types must match. + + The result of this operation is always equal to `val`. + }]; + + let arguments = (ins + CIR_AnyFundamentalIntType:$val, + CIR_AnyFundamentalIntType:$expected, + OptionalAttr:$prob + ); + + let results = (outs CIR_AnyFundamentalIntType:$result); + + let assemblyFormat = [{ + `(` $val`,` $expected (`,` $prob^)? `)` `:` type($val) attr-dict + }]; +} + #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp index cff139a7802df..b6f39d57b9403 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp @@ -100,6 +100,33 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, mlir::Value complex = builder.createComplexCreate(loc, real, imag); return RValue::get(complex); } + + case Builtin::BI__builtin_expect: + case Builtin::BI__builtin_expect_with_probability: { + mlir::Value argValue = emitScalarExpr(e->getArg(0)); + mlir::Value expectedValue = emitScalarExpr(e->getArg(1)); + + mlir::FloatAttr probAttr; + if (builtinIDIfNoAsmLabel == Builtin::BI__builtin_expect_with_probability) { + llvm::APFloat probability(0.0); + const Expr *probArg = e->getArg(2); + bool evalSucceeded = + probArg->EvaluateAsFloat(probability, cgm.getASTContext()); + assert(evalSucceeded && + "probability should be able to evaluate as float"); + (void)evalSucceeded; + bool loseInfo = false; + probability.convert(llvm::APFloat::IEEEdouble(), + llvm::RoundingMode::Dynamic, &loseInfo); + probAttr = mlir::FloatAttr::get(mlir::Float64Type::get(&getMLIRContext()), + probability); + } + + auto result = builder.create(getLoc(e->getSourceRange()), + argValue.getType(), argValue, + expectedValue, probAttr); + return RValue::get(result); + } } cgm.errorNYI(e->getSourceRange(), "unimplemented builtin call"); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 5f41e340e2474..8ee49b4b72f69 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -989,6 +989,19 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMExpectOpLowering::matchAndRewrite( + cir::ExpectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + std::optional prob = op.getProb(); + if (prob) + rewriter.replaceOpWithNewOp( + op, adaptor.getVal(), adaptor.getExpected(), prob.value()); + else + rewriter.replaceOpWithNewOp(op, adaptor.getVal(), + adaptor.getExpected()); + return mlir::success(); +} + /// Convert the `cir.func` attributes to `llvm.func` attributes. /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out @@ -1868,6 +1881,7 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMCallOpLowering, CIRToLLVMCmpOpLowering, CIRToLLVMConstantOpLowering, + CIRToLLVMExpectOpLowering, CIRToLLVMFuncOpLowering, CIRToLLVMGetGlobalOpLowering, CIRToLLVMGetMemberOpLowering, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index ae7247332c668..52959d61355b0 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -70,6 +70,16 @@ class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMExpectOpLowering + : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::ExpectOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMReturnOpLowering : public mlir::OpConversionPattern { public: diff --git a/clang/test/CIR/CodeGen/builtin_call.cpp b/clang/test/CIR/CodeGen/builtin_call.cpp index 0a2226a2cc592..bbe5e36b8bd99 100644 --- a/clang/test/CIR/CodeGen/builtin_call.cpp +++ b/clang/test/CIR/CodeGen/builtin_call.cpp @@ -110,3 +110,43 @@ void assume(bool arg) { // OGCG: define {{.*}}void @_Z6assumeb // OGCG: call void @llvm.assume(i1 %{{.+}}) // OGCG: } + +void expect(int x, int y) { + __builtin_expect(x, y); +} + +// CIR-LABEL: cir.func @_Z6expectii +// CIR: %[[X:.+]] = cir.load align(4) %{{.+}} : !cir.ptr, !s32i +// CIR-NEXT: %[[X_LONG:.+]] = cir.cast(integral, %[[X]] : !s32i), !s64i +// CIR-NEXT: %[[Y:.+]] = cir.load align(4) %{{.+}} : !cir.ptr, !s32i +// CIR-NEXT: %[[Y_LONG:.+]] = cir.cast(integral, %[[Y]] : !s32i), !s64i +// CIR-NEXT: %{{.+}} = cir.expect(%[[X_LONG]], %[[Y_LONG]]) : !s64i +// CIR: } + +// LLVM-LABEL: define void @_Z6expectii +// LLVM: %[[X:.+]] = load i32, ptr %{{.+}}, align 4 +// LLVM-NEXT: %[[X_LONG:.+]] = sext i32 %[[X]] to i64 +// LLVM-NEXT: %[[Y:.+]] = load i32, ptr %{{.+}}, align 4 +// LLVM-NEXT: %[[Y_LONG:.+]] = sext i32 %[[Y]] to i64 +// LLVM-NEXT: %{{.+}} = call i64 @llvm.expect.i64(i64 %[[X_LONG]], i64 %[[Y_LONG]]) +// LLVM: } + +void expect_prob(int x, int y) { + __builtin_expect_with_probability(x, y, 0.25); +} + +// CIR-LABEL: cir.func @_Z11expect_probii +// CIR: %[[X:.+]] = cir.load align(4) %{{.+}} : !cir.ptr, !s32i +// CIR-NEXT: %[[X_LONG:.+]] = cir.cast(integral, %[[X]] : !s32i), !s64i +// CIR-NEXT: %[[Y:.+]] = cir.load align(4) %{{.+}} : !cir.ptr, !s32i +// CIR-NEXT: %[[Y_LONG:.+]] = cir.cast(integral, %[[Y]] : !s32i), !s64i +// CIR-NEXT: %{{.+}} = cir.expect(%[[X_LONG]], %[[Y_LONG]], 2.500000e-01) : !s64i +// CIR: } + +// LLVM: define void @_Z11expect_probii +// LLVM: %[[X:.+]] = load i32, ptr %{{.+}}, align 4 +// LLVM-NEXT: %[[X_LONG:.+]] = sext i32 %[[X]] to i64 +// LLVM-NEXT: %[[Y:.+]] = load i32, ptr %{{.+}}, align 4 +// LLVM-NEXT: %[[Y_LONG:.+]] = sext i32 %[[Y]] to i64 +// LLVM-NEXT: %{{.+}} = call i64 @llvm.expect.with.probability.i64(i64 %[[X_LONG]], i64 %[[Y_LONG]], double 2.500000e-01) +// LLVM: }