Skip to content

Commit 6f03ca6

Browse files
authored
[CIR][CIRGen][builtin] handle __lzcnt (#1382)
Traditional clang implementation: https://github.com/llvm/clangir/blob/a1ab6bf6cd3b83d0982c16f29e8c98958f69c024/clang/lib/CodeGen/CGBuiltin.cpp#L3618-L3632 The problem here is that `__builtin_clz` allows undefined result, while `__lzcnt` doesn't. As a result, I have to create a new CIR for `__lzcnt`. Since the return type of those two builtin differs, I decided to change return type of current `CIR_BitOp` to allow new `CIR_LzcntOp` to inherit from it. I would like to hear your suggestions. C.c. @Lancern
1 parent fa5b07c commit 6f03ca6

File tree

11 files changed

+360
-380
lines changed

11 files changed

+360
-380
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,15 +1559,24 @@ def ComplexBinOp : CIR_Op<"complex.binop",
15591559
//===----------------------------------------------------------------------===//
15601560

15611561
class CIR_BitOp<string mnemonic, TypeConstraint inputTy>
1562-
: CIR_Op<mnemonic, [Pure]> {
1562+
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
15631563
let arguments = (ins inputTy:$input);
1564-
let results = (outs SInt32:$result);
1564+
let results = (outs inputTy:$result);
15651565

15661566
let assemblyFormat = [{
15671567
`(` $input `:` type($input) `)` `:` type($result) attr-dict
15681568
}];
15691569
}
15701570

1571+
class CIR_CountZerosBitOp<string mnemonic, TypeConstraint inputTy>
1572+
: CIR_BitOp<mnemonic, inputTy> {
1573+
let arguments = (ins inputTy:$input, UnitAttr:$is_zero_poison);
1574+
let assemblyFormat = [{
1575+
`(` $input `:` type($input) `)` (`zero_poison` $is_zero_poison^)?
1576+
`:` type($result) attr-dict
1577+
}];
1578+
}
1579+
15711580
def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
15721581
let summary = "Get the number of leading redundant sign bits in the input";
15731582
let description = [{
@@ -1599,7 +1608,7 @@ def BitClrsbOp : CIR_BitOp<"bit.clrsb", AnyTypeOf<[SInt32, SInt64]>> {
15991608
}];
16001609
}
16011610

1602-
def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
1611+
def BitClzOp : CIR_CountZerosBitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
16031612
let summary = "Get the number of leading 0-bits in the input";
16041613
let description = [{
16051614
Compute the number of leading 0-bits in the input.
@@ -1608,23 +1617,23 @@ def BitClzOp : CIR_BitOp<"bit.clz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
16081617
returns the number of consecutive 0-bits at the most significant bit
16091618
position in the input.
16101619

1611-
This operation invokes undefined behavior if the input value is 0.
1620+
Zero_poison attribute means this operation invokes undefined behavior if the
1621+
input value is 0.
16121622

16131623
Example:
16141624

16151625
```mlir
1616-
!s32i = !cir.int<s, 32>
16171626
!u32i = !cir.int<u, 32>
16181627

16191628
// %0 = 0b0000_0000_0000_0000_0000_0000_0000_1000
16201629
%0 = cir.const #cir.int<8> : !u32i
16211630
// %1 will be 28
1622-
%1 = cir.bit.clz(%0 : !u32i) : !s32i
1631+
%1 = cir.bit.clz(%0 : !u32i) zero_poison : !u32i
16231632
```
16241633
}];
16251634
}
16261635

1627-
def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
1636+
def BitCtzOp : CIR_CountZerosBitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
16281637
let summary = "Get the number of trailing 0-bits in the input";
16291638
let description = [{
16301639
Compute the number of trailing 0-bits in the input.
@@ -1633,7 +1642,8 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
16331642
returns the number of consecutive 0-bits at the least significant bit
16341643
position in the input.
16351644

1636-
This operation invokes undefined behavior if the input value is 0.
1645+
Zero_poison attribute means this operation invokes undefined behavior if the
1646+
input value is 0.
16371647

16381648
Example:
16391649

@@ -1644,7 +1654,7 @@ def BitCtzOp : CIR_BitOp<"bit.ctz", AnyTypeOf<[UInt16, UInt32, UInt64]>> {
16441654
// %0 = 0b1000
16451655
%0 = cir.const #cir.int<8> : !u32i
16461656
// %1 will be 3
1647-
%1 = cir.bit.ctz(%0 : !u32i) : !s32i
1657+
%1 = cir.bit.ctz(%0 : !u32i) : !u32i
16481658
```
16491659
}];
16501660
}

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,30 @@ static mlir::Value emitBinaryMaybeConstrainedFPBuiltin(CIRGenFunction &CGF,
128128
}
129129

130130
template <typename Op>
131-
static RValue
132-
emitBuiltinBitOp(CIRGenFunction &CGF, const CallExpr *E,
133-
std::optional<CIRGenFunction::BuiltinCheckKind> CK) {
131+
static RValue emitBuiltinBitOp(
132+
CIRGenFunction &CGF, const CallExpr *E,
133+
std::optional<CIRGenFunction::BuiltinCheckKind> CK = std::nullopt,
134+
bool isZeroPoison = false, bool convertToInt = true) {
134135
mlir::Value arg;
135136
if (CK.has_value())
136137
arg = CGF.emitCheckedArgForBuiltin(E->getArg(0), *CK);
137138
else
138139
arg = CGF.emitScalarExpr(E->getArg(0));
139140

140-
auto resultTy = CGF.convertType(E->getType());
141-
auto op =
142-
CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), resultTy, arg);
143-
return RValue::get(op);
141+
Op op;
142+
if constexpr (std::is_same_v<Op, cir::BitClzOp> ||
143+
std::is_same_v<Op, cir::BitCtzOp>) {
144+
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg,
145+
isZeroPoison);
146+
} else {
147+
op = CGF.getBuilder().create<Op>(CGF.getLoc(E->getExprLoc()), arg);
148+
}
149+
const mlir::Value bitResult = op.getResult();
150+
if (const auto si32Ty = CGF.getBuilder().getSInt32Ty();
151+
convertToInt && arg.getType() != si32Ty) {
152+
return RValue::get(CGF.getBuilder().createIntCast(bitResult, si32Ty));
153+
}
154+
return RValue::get(bitResult);
144155
}
145156

146157
// Initialize the alloca with the given size and alignment according to the lang
@@ -1052,46 +1063,54 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
10521063

10531064
case Builtin::BI__builtin_clrsb:
10541065
case Builtin::BI__builtin_clrsbl:
1055-
case Builtin::BI__builtin_clrsbll:
1056-
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E, std::nullopt);
1066+
case Builtin::BI__builtin_clrsbll: {
1067+
return emitBuiltinBitOp<cir::BitClrsbOp>(*this, E);
1068+
}
10571069

10581070
case Builtin::BI__builtin_ctzs:
10591071
case Builtin::BI__builtin_ctz:
10601072
case Builtin::BI__builtin_ctzl:
10611073
case Builtin::BI__builtin_ctzll:
1062-
case Builtin::BI__builtin_ctzg:
1063-
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero);
1074+
case Builtin::BI__builtin_ctzg: {
1075+
return emitBuiltinBitOp<cir::BitCtzOp>(*this, E, BCK_CTZPassedZero, true);
1076+
}
10641077

10651078
case Builtin::BI__builtin_clzs:
10661079
case Builtin::BI__builtin_clz:
10671080
case Builtin::BI__builtin_clzl:
10681081
case Builtin::BI__builtin_clzll:
1069-
case Builtin::BI__builtin_clzg:
1070-
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero);
1082+
case Builtin::BI__builtin_clzg: {
1083+
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, true);
1084+
}
10711085

10721086
case Builtin::BI__builtin_ffs:
10731087
case Builtin::BI__builtin_ffsl:
1074-
case Builtin::BI__builtin_ffsll:
1075-
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E, std::nullopt);
1088+
case Builtin::BI__builtin_ffsll: {
1089+
return emitBuiltinBitOp<cir::BitFfsOp>(*this, E);
1090+
}
10761091

10771092
case Builtin::BI__builtin_parity:
10781093
case Builtin::BI__builtin_parityl:
1079-
case Builtin::BI__builtin_parityll:
1080-
return emitBuiltinBitOp<cir::BitParityOp>(*this, E, std::nullopt);
1094+
case Builtin::BI__builtin_parityll: {
1095+
return emitBuiltinBitOp<cir::BitParityOp>(*this, E);
1096+
}
10811097

10821098
case Builtin::BI__lzcnt16:
10831099
case Builtin::BI__lzcnt:
1084-
case Builtin::BI__lzcnt64:
1085-
llvm_unreachable("BI__lzcnt16 like NYI");
1100+
case Builtin::BI__lzcnt64: {
1101+
return emitBuiltinBitOp<cir::BitClzOp>(*this, E, BCK_CLZPassedZero, false,
1102+
false);
1103+
}
10861104

10871105
case Builtin::BI__popcnt16:
10881106
case Builtin::BI__popcnt:
10891107
case Builtin::BI__popcnt64:
10901108
case Builtin::BI__builtin_popcount:
10911109
case Builtin::BI__builtin_popcountl:
10921110
case Builtin::BI__builtin_popcountll:
1093-
case Builtin::BI__builtin_popcountg:
1094-
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E, std::nullopt);
1111+
case Builtin::BI__builtin_popcountg: {
1112+
return emitBuiltinBitOp<cir::BitPopcountOp>(*this, E);
1113+
}
10951114

10961115
case Builtin::BI__builtin_unpredictable: {
10971116
if (CGM.getCodeGenOpts().OptimizationLevel != 0)

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3059,38 +3059,6 @@ mlir::LogicalResult CIRToLLVMAssumeSepStorageOpLowering::matchAndRewrite(
30593059
return mlir::success();
30603060
}
30613061

3062-
mlir::Value createLLVMBitOp(mlir::Location loc,
3063-
const llvm::Twine &llvmIntrinBaseName,
3064-
mlir::Type resultTy, mlir::Value operand,
3065-
std::optional<bool> poisonZeroInputFlag,
3066-
mlir::ConversionPatternRewriter &rewriter) {
3067-
auto operandIntTy = mlir::cast<mlir::IntegerType>(operand.getType());
3068-
auto resultIntTy = mlir::cast<mlir::IntegerType>(resultTy);
3069-
3070-
std::string llvmIntrinName =
3071-
llvmIntrinBaseName.concat(".i")
3072-
.concat(std::to_string(operandIntTy.getWidth()))
3073-
.str();
3074-
3075-
// Note that LLVM intrinsic calls to bit intrinsics have the same type as the
3076-
// operand.
3077-
mlir::LLVM::CallIntrinsicOp op;
3078-
if (poisonZeroInputFlag.has_value()) {
3079-
auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>(
3080-
loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag));
3081-
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
3082-
operand.getType(),
3083-
{operand, poisonZeroInputValue});
3084-
} else {
3085-
op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName,
3086-
operand.getType(), operand);
3087-
}
3088-
3089-
return getLLVMIntCast(
3090-
rewriter, op->getResult(0), mlir::cast<mlir::IntegerType>(resultTy),
3091-
/*isUnsigned=*/true, operandIntTy.getWidth(), resultIntTy.getWidth());
3092-
}
3093-
30943062
mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
30953063
cir::BitClrsbOp op, OpAdaptor adaptor,
30963064
mlir::ConversionPatternRewriter &rewriter) const {
@@ -3111,8 +3079,8 @@ mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
31113079
op.getLoc(), isNeg, flipped, adaptor.getInput());
31123080

31133081
auto resTy = getTypeConverter()->convertType(op.getType());
3114-
auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select,
3115-
/*poisonZeroInputFlag=*/false, rewriter);
3082+
auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
3083+
op.getLoc(), resTy, select, false);
31163084

31173085
auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
31183086
auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one);
@@ -3147,9 +3115,8 @@ mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
31473115
cir::BitClzOp op, OpAdaptor adaptor,
31483116
mlir::ConversionPatternRewriter &rewriter) const {
31493117
auto resTy = getTypeConverter()->convertType(op.getType());
3150-
auto llvmOp =
3151-
createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(),
3152-
/*poisonZeroInputFlag=*/true, rewriter);
3118+
auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
3119+
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
31533120
rewriter.replaceOp(op, llvmOp);
31543121
return mlir::LogicalResult::success();
31553122
}
@@ -3158,9 +3125,8 @@ mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
31583125
cir::BitCtzOp op, OpAdaptor adaptor,
31593126
mlir::ConversionPatternRewriter &rewriter) const {
31603127
auto resTy = getTypeConverter()->convertType(op.getType());
3161-
auto llvmOp =
3162-
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
3163-
/*poisonZeroInputFlag=*/true, rewriter);
3128+
auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
3129+
op.getLoc(), resTy, adaptor.getInput(), op.getIsZeroPoison());
31643130
rewriter.replaceOp(op, llvmOp);
31653131
return mlir::LogicalResult::success();
31663132
}
@@ -3169,9 +3135,8 @@ mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite(
31693135
cir::BitFfsOp op, OpAdaptor adaptor,
31703136
mlir::ConversionPatternRewriter &rewriter) const {
31713137
auto resTy = getTypeConverter()->convertType(op.getType());
3172-
auto ctz =
3173-
createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(),
3174-
/*poisonZeroInputFlag=*/false, rewriter);
3138+
auto ctz = rewriter.create<mlir::LLVM::CountTrailingZerosOp>(
3139+
op.getLoc(), resTy, adaptor.getInput(), false);
31753140

31763141
auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
31773142
auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one);
@@ -3196,9 +3161,8 @@ mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
31963161
cir::BitParityOp op, OpAdaptor adaptor,
31973162
mlir::ConversionPatternRewriter &rewriter) const {
31983163
auto resTy = getTypeConverter()->convertType(op.getType());
3199-
auto popcnt =
3200-
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
3201-
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
3164+
auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
3165+
adaptor.getInput());
32023166

32033167
auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1);
32043168
auto popcntMod2 =
@@ -3212,9 +3176,8 @@ mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
32123176
cir::BitPopcountOp op, OpAdaptor adaptor,
32133177
mlir::ConversionPatternRewriter &rewriter) const {
32143178
auto resTy = getTypeConverter()->convertType(op.getType());
3215-
auto llvmOp =
3216-
createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(),
3217-
/*poisonZeroInputFlag=*/std::nullopt, rewriter);
3179+
auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy,
3180+
adaptor.getInput());
32183181
rewriter.replaceOp(op, llvmOp);
32193182
return mlir::LogicalResult::success();
32203183
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,6 @@ mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp(
7373
const llvm::Twine &intrinsicName, mlir::Type resultTy,
7474
mlir::ValueRange operands);
7575

76-
mlir::Value createLLVMBitOp(mlir::Location loc,
77-
const llvm::Twine &llvmIntrinBaseName,
78-
mlir::Type resultTy, mlir::Value operand,
79-
std::optional<bool> poisonZeroInputFlag,
80-
mlir::ConversionPatternRewriter &rewriter);
81-
8276
class CIRToLLVMCopyOpLowering : public mlir::OpConversionPattern<cir::CopyOp> {
8377
public:
8478
using mlir::OpConversionPattern<cir::CopyOp>::OpConversionPattern;

0 commit comments

Comments
 (0)