Skip to content

[DAGCombiner] Use getShiftAmountConstant where possible. #97683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 37 additions & 55 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4395,14 +4395,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
// fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
unsigned Log2Val = (-ConstValue1).logBase2();
EVT ShiftVT = getShiftAmountTy(N0.getValueType());

// FIXME: If the input is something that is easily negated (e.g. a
// single-use add), we should put the negate there.
return Matcher.getNode(
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
Matcher.getNode(ISD::SHL, DL, VT, N0,
DAG.getConstant(Log2Val, DL, ShiftVT)));
DAG.getShiftAmountConstant(Log2Val, VT, DL)));
}

// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
Expand Down Expand Up @@ -5101,9 +5100,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {

// fold (mulhs x, 1) -> (sra x, size(x)-1)
if (isOneConstant(N1))
return DAG.getNode(ISD::SRA, DL, VT, N0,
DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
getShiftAmountTy(VT)));
return DAG.getNode(
ISD::SRA, DL, VT, N0,
DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));

// fold (mulhs x, undef) -> 0
if (N0.isUndef() || N1.isUndef())
Expand All @@ -5121,8 +5120,7 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
DAG.getConstant(SimpleSize, DL,
getShiftAmountTy(N1.getValueType())));
DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
}
}
Expand Down Expand Up @@ -5192,8 +5190,7 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
DAG.getConstant(SimpleSize, DL,
getShiftAmountTy(N1.getValueType())));
DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
}
}
Expand Down Expand Up @@ -5404,8 +5401,7 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
// Compute the high part as N1.
Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
DAG.getConstant(SimpleSize, DL,
getShiftAmountTy(Lo.getValueType())));
DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
// Compute the low part as N0.
Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
Expand Down Expand Up @@ -5458,8 +5454,7 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
// Compute the high part as N1.
Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
DAG.getConstant(SimpleSize, DL,
getShiftAmountTy(Lo.getValueType())));
DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
// Compute the low part as N0.
Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
Expand Down Expand Up @@ -7484,8 +7479,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
if (OpSizeInBits > 16) {
SDLoc DL(N);
Res = DAG.getNode(ISD::SRL, DL, VT, Res,
DAG.getConstant(OpSizeInBits - 16, DL,
getShiftAmountTy(VT)));
DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
}
return Res;
}
Expand Down Expand Up @@ -7603,7 +7597,7 @@ static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
// (rotr (bswap A), 16)
static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
SelectionDAG &DAG, SDNode *N, SDValue N0,
SDValue N1, EVT VT, EVT ShiftAmountTy) {
SDValue N1, EVT VT) {
assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
"MatchBSwapHWordOrAndAnd: expecting i32");
if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
Expand Down Expand Up @@ -7635,7 +7629,7 @@ static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,

SDLoc DL(N);
SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
}

Expand All @@ -7655,13 +7649,11 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();

if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
getShiftAmountTy(VT)))
if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
return BSwap;

// Try again with commuted operands.
if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
getShiftAmountTy(VT)))
if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
return BSwap;


Expand Down Expand Up @@ -7698,7 +7690,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {

// Result of the bswap should be rotated by 16. If it's not legal, then
// do (x << 16) | (x >> 16).
SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
Expand Down Expand Up @@ -10430,8 +10422,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
TLI.isTruncateFree(VT, TruncVT)) {
SDValue Amt = DAG.getConstant(ShiftAmt, DL,
getShiftAmountTy(N0.getOperand(0).getValueType()));
SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
N0.getOperand(0), Amt);
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
Expand Down Expand Up @@ -10679,10 +10670,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
uint64_t ShiftAmt = N1C->getZExtValue();
SDLoc DL0(N0);
SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
N0.getOperand(0),
DAG.getConstant(ShiftAmt, DL0,
getShiftAmountTy(SmallVT)));
SDValue SmallShift =
DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
AddToWorklist(SmallShift.getNode());
APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
return DAG.getNode(ISD::AND, DL, VT,
Expand Down Expand Up @@ -10726,8 +10716,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (ShAmt) {
SDLoc DL(N0);
Op = DAG.getNode(ISD::SRL, DL, VT, Op,
DAG.getConstant(ShAmt, DL,
getShiftAmountTy(Op.getValueType())));
DAG.getShiftAmountConstant(ShAmt, VT, DL));
AddToWorklist(Op.getNode());
}
return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
Expand Down Expand Up @@ -11086,7 +11075,7 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
SDValue Res = N0.getOperand(0);
if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
Res = DAG.getNode(ISD::SHL, DL, VT, Res,
DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
DAG.getShiftAmountConstant(NewShAmt, VT, DL));
Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
return DAG.getZExtOrTrunc(Res, DL, VT);
Expand Down Expand Up @@ -12316,9 +12305,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
return DAG.getNode(ISD::ABS, DL, VT, LHS);

SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
DAG.getConstant(VT.getScalarSizeInBits() - 1,
DL, getShiftAmountTy(VT)));
SDValue Shift = DAG.getNode(
ISD::SRA, DL, VT, LHS,
DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
AddToWorklist(Shift.getNode());
AddToWorklist(Add.getNode());
Expand Down Expand Up @@ -14625,18 +14614,15 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
// Shift the result left, if we've swallowed a left shift.
SDValue Result = Load;
if (ShLeftAmt != 0) {
EVT ShImmTy = getShiftAmountTy(Result.getValueType());
if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
ShImmTy = VT;
// If the shift amount is as large as the result size (but, presumably,
// no larger than the source) then the useful bits of the result are
// zero; we can't simply return the shortened shift, because the result
// of that operation is undefined.
if (ShLeftAmt >= VT.getScalarSizeInBits())
Result = DAG.getConstant(0, DL, VT);
else
Result = DAG.getNode(ISD::SHL, DL, VT,
Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
Result = DAG.getNode(ISD::SHL, DL, VT, Result,
DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
}

if (ShiftedOffset != 0) {
Expand Down Expand Up @@ -16898,7 +16884,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {

// Perform actual transform.
SDValue MantissaShiftCnt =
DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
// TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
// `(X << C1) + (C << C1)`, but that isn't always the case because of the
// cast. We could implement that by handle here to handle the casts.
Expand Down Expand Up @@ -19811,9 +19797,9 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
// shifted by ByteShift and truncated down to NumBytes.
if (ByteShift) {
SDLoc DL(IVal);
IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
DAG.getConstant(ByteShift*8, DL,
DC->getShiftAmountTy(IVal.getValueType())));
IVal = DAG.getNode(
ISD::SRL, DL, IVal.getValueType(), IVal,
DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
}

// Figure out the offset for the store and the alignment of the access.
Expand Down Expand Up @@ -27422,12 +27408,11 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,

// and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
// constant.
EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
AddToWorklist(Shift.getNode());

Expand All @@ -27447,7 +27432,7 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
if (TLI.shouldAvoidTransformToShift(XType, ShCt))
return SDValue();

SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
AddToWorklist(Shift.getNode());

Expand Down Expand Up @@ -27661,16 +27646,13 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
const APInt &AndMask = ConstAndRHS->getAPIntValue();
if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
unsigned ShCt = AndMask.getBitWidth() - 1;
SDValue ShlAmt =
DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS),
getShiftAmountTy(AndLHS.getValueType()));
SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
SDLoc(AndLHS));
SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);

// Now arithmetic right shift it all the way over, so the result is
// either all-ones, or zero.
SDValue ShrAmt =
DAG.getConstant(ShCt, SDLoc(Shl),
getShiftAmountTy(Shl.getValueType()));
SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);

return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
Expand Down Expand Up @@ -27718,9 +27700,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
return SDValue();

// shl setcc result by log2 n2c
return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
DAG.getConstant(ShCt, SDLoc(Temp),
getShiftAmountTy(Temp.getValueType())));
return DAG.getNode(
ISD::SHL, DL, N2.getValueType(), Temp,
DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
}

// select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
Expand Down
10 changes: 4 additions & 6 deletions llvm/test/CodeGen/X86/shift-combine.ll
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,10 @@ define i64 @ashr_add_neg_shl_i32(i64 %r) nounwind {
define i64 @ashr_add_neg_shl_i8(i64 %r) nounwind {
; X86-LABEL: ashr_add_neg_shl_i8:
; X86: # %bb.0:
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-NEXT: shll $24, %eax
; X86-NEXT: movl $33554432, %edx # imm = 0x2000000
; X86-NEXT: subl %eax, %edx
; X86-NEXT: movl %edx, %eax
; X86-NEXT: sarl $24, %eax
; X86-NEXT: movb $2, %al
; X86-NEXT: subb {{[0-9]+}}(%esp), %al
; X86-NEXT: movsbl %al, %eax
; X86-NEXT: movl %eax, %edx
; X86-NEXT: sarl $31, %edx
; X86-NEXT: retl
;
Expand Down
Loading