Skip to content

Commit 61a4e1e

Browse files
authored
[DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (#98646)
The plan is to add more TernaryOp in the future (SELECT/VSELECT and FMA in particular)
1 parent 33af112 commit 61a4e1e

File tree

3 files changed

+92
-31
lines changed

3 files changed

+92
-31
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
447447
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
448448
};
449449

450+
// === Ternary operations ===
451+
template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
452+
bool ExcludeChain = false>
453+
struct TernaryOpc_match {
454+
unsigned Opcode;
455+
T0_P Op0;
456+
T1_P Op1;
457+
T2_P Op2;
458+
459+
TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
460+
const T2_P &Op2)
461+
: Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
462+
463+
template <typename MatchContext>
464+
bool match(const MatchContext &Ctx, SDValue N) {
465+
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
466+
EffectiveOperands<ExcludeChain> EO(N);
467+
assert(EO.Size == 3);
468+
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
469+
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
470+
(Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
471+
Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
472+
Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
473+
}
474+
475+
return false;
476+
}
477+
};
478+
479+
template <typename T0_P, typename T1_P, typename T2_P>
480+
inline TernaryOpc_match<T0_P, T1_P, T2_P, false, false>
481+
m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
482+
return TernaryOpc_match<T0_P, T1_P, T2_P, false, false>(ISD::SETCC, Op0, Op1,
483+
Op2);
484+
}
485+
486+
template <typename T0_P, typename T1_P, typename T2_P>
487+
inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
488+
m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
489+
return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, Op0, Op1,
490+
Op2);
491+
}
492+
450493
// === Binary operations ===
451494
template <typename LHS_P, typename RHS_P, bool Commutable = false,
452495
bool ExcludeChain = false>

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
23002300
return true;
23012301
}
23022302

2303-
if (N.getOpcode() != ISD::SETCC ||
2304-
N.getValueType().getScalarType() != MVT::i1 ||
2305-
cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
2306-
return false;
2307-
2308-
SDValue Op0 = N->getOperand(0);
2309-
SDValue Op1 = N->getOperand(1);
2310-
assert(Op0.getValueType() == Op1.getValueType());
2311-
2312-
if (isNullOrNullSplat(Op0))
2313-
Op = Op1;
2314-
else if (isNullOrNullSplat(Op1))
2315-
Op = Op0;
2316-
else
2303+
if (N.getValueType().getScalarType() != MVT::i1 ||
2304+
!sd_match(
2305+
N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
23172306
return false;
23182307

23192308
Known = DAG.computeKnownBits(Op);
2320-
23212309
return (Known.Zero | 1).isAllOnes();
23222310
}
23232311

@@ -2544,26 +2532,22 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
25442532
return SDValue();
25452533

25462534
// Match the zext operand as a setcc of a boolean.
2547-
if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
2548-
Z.getOperand(0).getValueType() != MVT::i1)
2535+
if (Z.getOperand(0).getValueType() != MVT::i1)
25492536
return SDValue();
25502537

25512538
// Match the compare as: setcc (X & 1), 0, eq.
2552-
SDValue SetCC = Z.getOperand(0);
2553-
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
2554-
if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
2555-
SetCC.getOperand(0).getOpcode() != ISD::AND ||
2556-
!isOneConstant(SetCC.getOperand(0).getOperand(1)))
2539+
if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
2540+
m_SpecificCondCode(ISD::SETEQ))))
25572541
return SDValue();
25582542

25592543
// We are adding/subtracting a constant and an inverted low bit. Turn that
25602544
// into a subtract/add of the low bit with incremented/decremented constant:
25612545
// add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
25622546
// sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
25632547
EVT VT = C.getValueType();
2564-
SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
2565-
SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
2566-
DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
2548+
SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
2549+
SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
2550+
: DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
25672551
return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
25682552
}
25692553

@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
1155411538
SDValue N1 = N->getOperand(1);
1155511539
SDValue N2 = N->getOperand(2);
1155611540
EVT VT = N->getValueType(0);
11557-
if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
11558-
return SDValue();
1155911541

11560-
SDValue Cond0 = N0.getOperand(0);
11561-
SDValue Cond1 = N0.getOperand(1);
11562-
ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
11563-
if (VT != Cond0.getValueType())
11542+
SDValue Cond0, Cond1;
11543+
ISD::CondCode CC;
11544+
if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
11545+
m_CondCode(CC)))) ||
11546+
VT != Cond0.getValueType())
1156411547
return SDValue();
1156511548

1156611549
// Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,41 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
119119
EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
120120
}
121121

122+
TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
123+
SDLoc DL;
124+
auto Int32VT = EVT::getIntegerVT(Context, 32);
125+
126+
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
127+
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
128+
129+
SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT);
130+
SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ);
131+
SDValue ICMP_EQ10 = DAG->getSetCC(DL, MVT::i1, Op1, Op0, ISD::SETEQ);
132+
133+
using namespace SDPatternMatch;
134+
ISD::CondCode CC;
135+
EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
136+
m_SpecificCondCode(ISD::SETUGT))));
137+
EXPECT_TRUE(
138+
sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC))));
139+
EXPECT_TRUE(CC == ISD::SETUGT);
140+
EXPECT_FALSE(sd_match(
141+
ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_SpecificCondCode(ISD::SETLE))));
142+
143+
EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1),
144+
m_SpecificCondCode(ISD::SETEQ))));
145+
EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0),
146+
m_SpecificCondCode(ISD::SETEQ))));
147+
EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0),
148+
m_SpecificCondCode(ISD::SETEQ))));
149+
EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1),
150+
m_SpecificCondCode(ISD::SETEQ))));
151+
EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0),
152+
m_SpecificCondCode(ISD::SETEQ))));
153+
EXPECT_TRUE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1),
154+
m_SpecificCondCode(ISD::SETEQ))));
155+
}
156+
122157
TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
123158
SDLoc DL;
124159
auto Int32VT = EVT::getIntegerVT(Context, 32);

0 commit comments

Comments
 (0)