Skip to content

[DAGCombiner] Use getShiftAmountConstant where possible. #97683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 4, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Jul 4, 2024

In #97645, I proposed removing the LegalTypes operand to TargetLowering::getShiftAmountTy. This means we don't need to use the DAGCombiner wrapper for getShiftAmountTy that manages this flag. Now we can use getShiftAmountConstant and let it call TargetLowering::getShiftAmountTy.

This contains the same X86 test change as #97645.

In llvm#97645, I proposed removing the LegalTypes operand to
TargetLowering::getShiftAmountTy. This means we don't need to
use the DAGCombiner wrapper for getShiftAmountTy that manages this
flag. Now we can use getShiftAmountOrConstant and let it call
TargetLowering::getShiftAmountTy.

This contains the same X86 test change as llvm#97645.
@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels Jul 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

Changes

In #97645, I proposed removing the LegalTypes operand to TargetLowering::getShiftAmountTy. This means we don't need to use the DAGCombiner wrapper for getShiftAmountTy that manages this flag. Now we can use getShiftAmountOrConstant and let it call TargetLowering::getShiftAmountTy.

This contains the same X86 test change as #97645.


Full diff: https://github.com/llvm/llvm-project/pull/97683.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+37-55)
  • (modified) llvm/test/CodeGen/X86/shift-combine.ll (+4-6)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d81a54d2ecaaa..a7b469ea17f18 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4395,14 +4395,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
     unsigned Log2Val = (-ConstValue1).logBase2();
-    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
 
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
     return Matcher.getNode(
         ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
         Matcher.getNode(ISD::SHL, DL, VT, N0,
-                        DAG.getConstant(Log2Val, DL, ShiftVT)));
+                        DAG.getShiftAmountConstant(Log2Val, VT, DL)));
   }
 
   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
@@ -5101,9 +5100,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
 
   // fold (mulhs x, 1) -> (sra x, size(x)-1)
   if (isOneConstant(N1))
-    return DAG.getNode(ISD::SRA, DL, VT, N0,
-                       DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
-                                       getShiftAmountTy(VT)));
+    return DAG.getNode(
+        ISD::SRA, DL, VT, N0,
+        DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
 
   // fold (mulhs x, undef) -> 0
   if (N0.isUndef() || N1.isUndef())
@@ -5121,8 +5120,7 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(N1.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
@@ -5192,8 +5190,7 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(N1.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
@@ -5404,8 +5401,7 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(Lo.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -5458,8 +5454,7 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(Lo.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -7484,8 +7479,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
   if (OpSizeInBits > 16) {
     SDLoc DL(N);
     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
-                      DAG.getConstant(OpSizeInBits - 16, DL,
-                                      getShiftAmountTy(VT)));
+                      DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
   }
   return Res;
 }
@@ -7603,7 +7597,7 @@ static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
 //   (rotr (bswap A), 16)
 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
-                                       SDValue N1, EVT VT, EVT ShiftAmountTy) {
+                                       SDValue N1, EVT VT) {
   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
          "MatchBSwapHWordOrAndAnd: expecting i32");
   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -7635,7 +7629,7 @@ static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
 
   SDLoc DL(N);
   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
-  SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
+  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
 }
 
@@ -7655,13 +7649,11 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
     return SDValue();
 
-  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
-                                              getShiftAmountTy(VT)))
+  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
     return BSwap;
 
   // Try again with commuted operands.
-  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
-                                              getShiftAmountTy(VT)))
+  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
     return BSwap;
 
 
@@ -7698,7 +7690,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
 
   // Result of the bswap should be rotated by 16. If it's not legal, then
   // do  (x << 16) | (x >> 16).
-  SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
+  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -10430,8 +10422,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
           TLI.isTruncateFree(VT, TruncVT)) {
-        SDValue Amt = DAG.getConstant(ShiftAmt, DL,
-            getShiftAmountTy(N0.getOperand(0).getValueType()));
+        SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
                                     N0.getOperand(0), Amt);
         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
@@ -10679,10 +10670,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
       uint64_t ShiftAmt = N1C->getZExtValue();
       SDLoc DL0(N0);
-      SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
-                                       N0.getOperand(0),
-                          DAG.getConstant(ShiftAmt, DL0,
-                                          getShiftAmountTy(SmallVT)));
+      SDValue SmallShift =
+          DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
+                      DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
       AddToWorklist(SmallShift.getNode());
       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
       return DAG.getNode(ISD::AND, DL, VT,
@@ -10726,8 +10716,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
       if (ShAmt) {
         SDLoc DL(N0);
         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
-                  DAG.getConstant(ShAmt, DL,
-                                  getShiftAmountTy(Op.getValueType())));
+                         DAG.getShiftAmountConstant(ShAmt, VT, DL));
         AddToWorklist(Op.getNode());
       }
       return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
@@ -11086,7 +11075,7 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
       SDValue Res = N0.getOperand(0);
       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
-                          DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
+                          DAG.getShiftAmountConstant(NewShAmt, VT, DL));
       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
       return DAG.getZExtOrTrunc(Res, DL, VT);
@@ -12316,9 +12305,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
         return DAG.getNode(ISD::ABS, DL, VT, LHS);
 
-      SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
-                                  DAG.getConstant(VT.getScalarSizeInBits() - 1,
-                                                  DL, getShiftAmountTy(VT)));
+      SDValue Shift = DAG.getNode(
+          ISD::SRA, DL, VT, LHS,
+          DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
       AddToWorklist(Shift.getNode());
       AddToWorklist(Add.getNode());
@@ -14625,9 +14614,6 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
   // Shift the result left, if we've swallowed a left shift.
   SDValue Result = Load;
   if (ShLeftAmt != 0) {
-    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
-    if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
-      ShImmTy = VT;
     // If the shift amount is as large as the result size (but, presumably,
     // no larger than the source) then the useful bits of the result are
     // zero; we can't simply return the shortened shift, because the result
@@ -14635,8 +14621,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
     if (ShLeftAmt >= VT.getScalarSizeInBits())
       Result = DAG.getConstant(0, DL, VT);
     else
-      Result = DAG.getNode(ISD::SHL, DL, VT,
-                          Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
+      Result = DAG.getNode(ISD::SHL, DL, VT, Result,
+                           DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
   }
 
   if (ShiftedOffset != 0) {
@@ -16898,7 +16884,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
 
   // Perform actual transform.
   SDValue MantissaShiftCnt =
-      DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
+      DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
   // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
   // `(X << C1) + (C << C1)`, but that isn't always the case because of the
   // cast. We could implement that by handle here to handle the casts.
@@ -19811,9 +19797,9 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
   // shifted by ByteShift and truncated down to NumBytes.
   if (ByteShift) {
     SDLoc DL(IVal);
-    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
-                       DAG.getConstant(ByteShift*8, DL,
-                                    DC->getShiftAmountTy(IVal.getValueType())));
+    IVal = DAG.getNode(
+        ISD::SRL, DL, IVal.getValueType(), IVal,
+        DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
   }
 
   // Figure out the offset for the store and the alignment of the access.
@@ -27422,12 +27408,11 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
 
   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
   // constant.
-  EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
-      SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+      SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
       AddToWorklist(Shift.getNode());
 
@@ -27447,7 +27432,7 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
     return SDValue();
 
-  SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+  SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
   AddToWorklist(Shift.getNode());
 
@@ -27661,16 +27646,13 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
       const APInt &AndMask = ConstAndRHS->getAPIntValue();
       if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
         unsigned ShCt = AndMask.getBitWidth() - 1;
-        SDValue ShlAmt =
-            DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS),
-                            getShiftAmountTy(AndLHS.getValueType()));
+        SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
+                                                    SDLoc(AndLHS));
         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
 
         // Now arithmetic right shift it all the way over, so the result is
         // either all-ones, or zero.
-        SDValue ShrAmt =
-          DAG.getConstant(ShCt, SDLoc(Shl),
-                          getShiftAmountTy(Shl.getValueType()));
+        SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
 
         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
@@ -27718,9 +27700,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
       return SDValue();
 
     // shl setcc result by log2 n2c
-    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
-                       DAG.getConstant(ShCt, SDLoc(Temp),
-                                       getShiftAmountTy(Temp.getValueType())));
+    return DAG.getNode(
+        ISD::SHL, DL, N2.getValueType(), Temp,
+        DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
   }
 
   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
diff --git a/llvm/test/CodeGen/X86/shift-combine.ll b/llvm/test/CodeGen/X86/shift-combine.ll
index 30c3d53dd37c9..c9edd3f3e9048 100644
--- a/llvm/test/CodeGen/X86/shift-combine.ll
+++ b/llvm/test/CodeGen/X86/shift-combine.ll
@@ -444,12 +444,10 @@ define i64 @ashr_add_neg_shl_i32(i64 %r) nounwind {
 define i64 @ashr_add_neg_shl_i8(i64 %r) nounwind {
 ; X86-LABEL: ashr_add_neg_shl_i8:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    shll $24, %eax
-; X86-NEXT:    movl $33554432, %edx # imm = 0x2000000
-; X86-NEXT:    subl %eax, %edx
-; X86-NEXT:    movl %edx, %eax
-; X86-NEXT:    sarl $24, %eax
+; X86-NEXT:    movb $2, %al
+; X86-NEXT:    subb {{[0-9]+}}(%esp), %al
+; X86-NEXT:    movsbl %al, %eax
+; X86-NEXT:    movl %eax, %edx
 ; X86-NEXT:    sarl $31, %edx
 ; X86-NEXT:    retl
 ;

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc merged commit 34fe032 into llvm:main Jul 4, 2024
10 checks passed
@topperc topperc deleted the pr/dag-combiner-getshiftamountconstant branch July 4, 2024 15:44
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
In llvm#97645, I proposed removing the LegalTypes operand to
TargetLowering::getShiftAmountTy. This means we don't need to use the
DAGCombiner wrapper for getShiftAmountTy that manages this flag. Now we
can use getShiftAmountConstant and let it call
TargetLowering::getShiftAmountTy.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants