Skip to content

[CIR][CIRGen][builtin] handle __lzcnt #1382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 12, 2025
Merged
28 changes: 19 additions & 9 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1559,15 +1559,24 @@ def ComplexBinOp : CIR_Op<"complex.binop",
//===----------------------------------------------------------------------===//

class CIR_BitOp<string mnemonic, TypeConstraint inputTy>
: CIR_Op<mnemonic, [Pure]> {
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins inputTy:$input);
let results = (outs SInt32:$result);
let results = (outs inputTy:$result);

let assemblyFormat = [{
`(` $input `:` type($input) `)` `:` type($result) attr-dict
}];
}

class CIR_CountZerosBitOp<string mnemonic, TypeConstraint inputTy>
: CIR_BitOp<mnemonic, inputTy> {
let arguments = (ins inputTy:$input, UnitAttr:$is_zero_poison);
let assemblyFormat = [{
`(` $input `:` type($input) `)` (`zero_poison` $is_zero_poison^)?
`:` type($result) attr-dict
}];
}

def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
let summary = "Get the number of leading redundant sign bits in the input";
let description = [{
Expand Down Expand Up @@ -1599,7 +1608,7 @@ def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
}];
}

def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
def BitClzOp : CIR_CountZerosBitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of leading 0-bits in the input";
let description = [{
Compute the number of leading 0-bits in the input.
Expand All @@ -1608,23 +1617,23 @@ def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
returns the number of consecutive 0-bits at the most significant bit
position in the input.

This operation invokes undefined behavior if the input value is 0.
Zero_poison attribute means this operation invokes undefined behavior if the
input value is 0.

Example:

```mlir
!s32i = !cir.int<s, 32>
!u32i = !cir.int<u, 32>

// %0 = 0b0000_0000_0000_0000_0000_0000_0000_1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 28
%1 = cir.bit.clz(%0 : !u32i) : !s32i
%1 = cir.bit.clz(%0 : !u32i) zero_poison : !u32i
```
}];
}

def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
def BitCtzOp : CIR_CountZerosBitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
let summary = "Get the number of trailing 0-bits in the input";
let description = [{
Compute the number of trailing 0-bits in the input.
Expand All @@ -1633,7 +1642,8 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
returns the number of consecutive 0-bits at the least significant bit
position in the input.

This operation invokes undefined behavior if the input value is 0.
Zero_poison attribute means this operation invokes undefined behavior if the
input value is 0.

Example:

Expand All @@ -1644,7 +1654,7 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
// %0 = 0b1000
%0 = cir.const #cir.int<8> : !u32i
// %1 will be 3
%1 = cir.bit.ctz(%0 : !u32i) : !s32i
%1 = cir.bit.ctz(%0 : !u32i) : !u32i
```
}];
}
Expand Down
61 changes: 40 additions & 21 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,30 @@ static mlir::Value emitBinaryMaybeConstrainedFPBuiltin(CIRGenFunction &CGF,
}

template <typename Op>
static RValue
emitBuiltinBitOp(CIRGenFunction &CGF, const CallExpr *E,
std::optional<CIRGenFunction::BuiltinCheckKind> CK) {
static RValue emitBuiltinBitOp(
CIRGenFunction &CGF, const CallExpr *E,
std::optional<CIRGenFunction::BuiltinCheckKind> CK = std::nullopt,
bool isZeroPoison = false, bool convertToInt = true) {
mlir::Value arg;
if (CK.has_value())
arg = CGF.emitCheckedArgForBuiltin(E->getArg(0), *CK);
else
arg = CGF.emitScalarExpr(E->getArg(0));

auto resultTy = CGF.convertType(E->getType());
auto op =
CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), resultTy, arg);
return RValue::get(op);
Op op;
if constexpr (std::is_same_v<Op, cir::BitClzOp> ||
std::is_same_v<Op, cir::BitCtzOp>) {
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg,
isZeroPoison);
} else {
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg);
}
const mlir::Value bitResult = op.getResult();
if (const auto si32Ty = CGF.getBuilder().getSInt32Ty();
convertToInt && arg.getType() != si32Ty) {
return RValue::get(CGF.getBuilder().createIntCast(bitResult, si32Ty));
}
return RValue::get(bitResult);
}

// Initialize the alloca with the given size and alignment according to the lang
Expand Down Expand Up @@ -1052,46 +1063,54 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

case Builtin::BI__builtin_clrsb:
case Builtin::BI__builtin_clrsbl:
case Builtin::BI__builtin_clrsbll:
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_clrsbll: {
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E);
}

case Builtin::BI__builtin_ctzs:
case Builtin::BI__builtin_ctz:
case Builtin::BI__builtin_ctzl:
case Builtin::BI__builtin_ctzll:
case Builtin::BI__builtin_ctzg:
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero);
case Builtin::BI__builtin_ctzg: {
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero, true);
}

case Builtin::BI__builtin_clzs:
case Builtin::BI__builtin_clz:
case Builtin::BI__builtin_clzl:
case Builtin::BI__builtin_clzll:
case Builtin::BI__builtin_clzg:
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero);
case Builtin::BI__builtin_clzg: {
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, true);
}

case Builtin::BI__builtin_ffs:
case Builtin::BI__builtin_ffsl:
case Builtin::BI__builtin_ffsll:
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_ffsll: {
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E);
}

case Builtin::BI__builtin_parity:
case Builtin::BI__builtin_parityl:
case Builtin::BI__builtin_parityll:
return emitBuiltinBitOp<cir::BitParityOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_parityll: {
return emitBuiltinBitOp<cir::BitParityOp>(*this, E);
}

case Builtin::BI__lzcnt16:
case Builtin::BI__lzcnt:
case Builtin::BI__lzcnt64:
llvm_unreachable("BI__lzcnt16 like NYI");
case Builtin::BI__lzcnt64: {
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, false,
false);
}

case Builtin::BI__popcnt16:
case Builtin::BI__popcnt:
case Builtin::BI__popcnt64:
case Builtin::BI__builtin_popcount:
case Builtin::BI__builtin_popcountl:
case Builtin::BI__builtin_popcountll:
case Builtin::BI__builtin_popcountg:
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E, std::nullopt);
case Builtin::BI__builtin_popcountg: {
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E);
}

case Builtin::BI__builtin_unpredictable: {
if (CGM.getCodeGenOpts().OptimizationLevel != 0)
Expand Down
61 changes: 12 additions & 49 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3059,38 +3059,6 @@ mlir::LogicalResult CIRToLLVMAssumeSepStorageOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
std::optional<bool> poisonZeroInputFlag,
mlir::ConversionPatternRewriter &rewriter) {
auto operandIntTy = mlir::cast<mlir::IntegerType>(operand.getType());
auto resultIntTy = mlir::cast<mlir::IntegerType>(resultTy);

std::string llvmIntrinName =
llvmIntrinBaseName.concat(".i")
.concat(std::to_string(operandIntTy.getWidth()))
.str();

// Note that LLVM intrinsic calls to bit intrinsics have the same type as the
// operand.
mlir::LLVM::CallIntrinsicOp op;
if (poisonZeroInputFlag.has_value()) {
auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag));
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(),
{operand, poisonZeroInputValue});
} else {
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
operand.getType(), operand);
}

return getLLVMIntCast(
rewriter, op->getResult(0), mlir::cast<mlir::IntegerType>(resultTy),
/*isUnsigned=*/true, operandIntTy.getWidth(), resultIntTy.getWidth());
}

mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
cir::BitClrsbOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -3111,8 +3079,8 @@ mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
op.getLoc(), isNeg, flipped, adaptor.getInput());

auto resTy = getTypeConverter()->convertType(op.getType());
auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select,
/*poisonZeroInputFlag=*/false, rewriter);
auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, select, false);

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one);
Expand Down Expand Up @@ -3147,9 +3115,8 @@ mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
cir::BitClzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/true, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand All @@ -3158,9 +3125,8 @@ mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
cir::BitCtzOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/true, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand All @@ -3169,9 +3135,8 @@ mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite(
cir::BitFfsOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto ctz =
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/false, rewriter);
auto ctz = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
op.getLoc(), resTy, adaptor.getInput(), false);

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one);
Expand All @@ -3196,9 +3161,8 @@ mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
cir::BitParityOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto popcnt =
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());

auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
auto popcntMod2 =
Expand All @@ -3212,9 +3176,8 @@ mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
cir::BitPopcountOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
auto resTy = getTypeConverter()->convertType(op.getType());
auto llvmOp =
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
adaptor.getInput());
rewriter.replaceOp(op, llvmOp);
return mlir::LogicalResult::success();
}
Expand Down
6 changes: 0 additions & 6 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp(
const llvm::Twine &intrinsicName, mlir::Type resultTy,
mlir::ValueRange operands);

mlir::Value createLLVMBitOp(mlir::Location loc,
const llvm::Twine &llvmIntrinBaseName,
mlir::Type resultTy, mlir::Value operand,
std::optional<bool> poisonZeroInputFlag,
mlir::ConversionPatternRewriter &rewriter);

class CIRToLLVMCopyOpLowering : public mlir::OpConversionPattern<cir::CopyOp> {
public:
using mlir::OpConversionPattern<cir::CopyOp>::OpConversionPattern;
Expand Down
Loading