Skip to content

Commit 43e357f

Browse files
committed
[X86] Update sra(x,umin(amt,bw-1)) -> psrav(x,amt) fold to use SDPatternMatch. NFC.
First tentative attempt to use SDPatternMatch for x86 combine matching - main problem so far is namespace clashing when trying to expose llvm::SDPatternMatch to the entire file.
1 parent 3b362ee commit 43e357f

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/CodeGen/MachineLoopInfo.h"
3939
#include "llvm/CodeGen/MachineModuleInfo.h"
4040
#include "llvm/CodeGen/MachineRegisterInfo.h"
41+
#include "llvm/CodeGen/SDPatternMatch.h"
4142
#include "llvm/CodeGen/TargetLowering.h"
4243
#include "llvm/CodeGen/WinEHFuncInfo.h"
4344
#include "llvm/IR/CallingConv.h"
@@ -48084,22 +48085,22 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
4808448085

4808548086
static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
4808648087
const X86Subtarget &Subtarget) {
48088+
using namespace llvm::SDPatternMatch;
4808748089
SDValue N0 = N->getOperand(0);
4808848090
SDValue N1 = N->getOperand(1);
4808948091
EVT VT = N0.getValueType();
4809048092
unsigned Size = VT.getSizeInBits();
48093+
SDLoc DL(N);
4809148094

4809248095
if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
4809348096
return V;
4809448097

48095-
APInt ShiftAmt;
48096-
if (supportedVectorVarShift(VT, Subtarget, ISD::SRA) &&
48097-
N1.getOpcode() == ISD::UMIN &&
48098-
ISD::isConstantSplatVector(N1.getOperand(1).getNode(), ShiftAmt) &&
48099-
ShiftAmt == VT.getScalarSizeInBits() - 1) {
48100-
SDValue ShrAmtVal = N1.getOperand(0);
48101-
SDLoc DL(N);
48102-
return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
48098+
// fold sra(x,umin(amt,bw-1)) -> avx2 psrav(x,amt)
48099+
if (supportedVectorVarShift(VT, Subtarget, ISD::SRA)) {
48100+
SDValue ShrAmtVal;
48101+
if (sd_match(N1, m_UMin(m_Value(ShrAmtVal),
48102+
m_SpecificInt(VT.getScalarSizeInBits() - 1))))
48103+
return DAG.getNode(X86ISD::VSRAV, DL, VT, N0, ShrAmtVal);
4810348104
}
4810448105

4810548106
// fold (SRA (SHL X, ShlConst), SraConst)
@@ -48137,7 +48138,6 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
4813748138
// Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
4813848139
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
4813948140
continue;
48140-
SDLoc DL(N);
4814148141
SDValue NN =
4814248142
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
4814348143
if (SraConst.eq(ShlConst))

0 commit comments

Comments
 (0)