Skip to content

Commit 060b302

Browse files
authored
[RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (#93574)
I plan to add other combines on TRUNCATE_VECTOR_VL.
1 parent e3f74d4 commit 060b302

File tree

1 file changed

+53
-50
lines changed

1 file changed

+53
-50
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
1608716087
return true;
1608816088
}
1608916089

16090+
static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
16091+
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16092+
// This would be benefit for the cases where X and Y are both the same value
16093+
// type of low precision vectors. Since the truncate would be lowered into
16094+
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16095+
// restriction, such pattern would be expanded into a series of "vsetvli"
16096+
// and "vnsrl" instructions later to reach this point.
16097+
auto IsTruncNode = [](SDValue V) {
16098+
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
16099+
return false;
16100+
SDValue VL = V.getOperand(2);
16101+
auto *C = dyn_cast<ConstantSDNode>(VL);
16102+
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
16103+
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
16104+
(isa<RegisterSDNode>(VL) &&
16105+
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16106+
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
16107+
};
16108+
16109+
SDValue Op = N->getOperand(0);
16110+
16111+
// We need to first find the inner level of TRUNCATE_VECTOR_VL node
16112+
// to distinguish such pattern.
16113+
while (IsTruncNode(Op)) {
16114+
if (!Op.hasOneUse())
16115+
return SDValue();
16116+
Op = Op.getOperand(0);
16117+
}
16118+
16119+
if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
16120+
return SDValue();
16121+
16122+
SDValue N0 = Op.getOperand(0);
16123+
SDValue N1 = Op.getOperand(1);
16124+
if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
16125+
N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
16126+
return SDValue();
16127+
16128+
SDValue N00 = N0.getOperand(0);
16129+
SDValue N10 = N1.getOperand(0);
16130+
if (!N00.getValueType().isVector() ||
16131+
N00.getValueType() != N10.getValueType() ||
16132+
N->getValueType(0) != N10.getValueType())
16133+
return SDValue();
16134+
16135+
unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
16136+
SDValue SMin =
16137+
DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
16138+
DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
16139+
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
16140+
}
1609016141

1609116142
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1609216143
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1630416355
}
1630516356
}
1630616357
return SDValue();
16307-
case RISCVISD::TRUNCATE_VECTOR_VL: {
16308-
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16309-
// This would be benefit for the cases where X and Y are both the same value
16310-
// type of low precision vectors. Since the truncate would be lowered into
16311-
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16312-
// restriction, such pattern would be expanded into a series of "vsetvli"
16313-
// and "vnsrl" instructions later to reach this point.
16314-
auto IsTruncNode = [](SDValue V) {
16315-
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
16316-
return false;
16317-
SDValue VL = V.getOperand(2);
16318-
auto *C = dyn_cast<ConstantSDNode>(VL);
16319-
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
16320-
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
16321-
(isa<RegisterSDNode>(VL) &&
16322-
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16323-
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
16324-
IsVLMAXForVMSET;
16325-
};
16326-
16327-
SDValue Op = N->getOperand(0);
16328-
16329-
// We need to first find the inner level of TRUNCATE_VECTOR_VL node
16330-
// to distinguish such pattern.
16331-
while (IsTruncNode(Op)) {
16332-
if (!Op.hasOneUse())
16333-
return SDValue();
16334-
Op = Op.getOperand(0);
16335-
}
16336-
16337-
if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
16338-
SDValue N0 = Op.getOperand(0);
16339-
SDValue N1 = Op.getOperand(1);
16340-
if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
16341-
N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
16342-
SDValue N00 = N0.getOperand(0);
16343-
SDValue N10 = N1.getOperand(0);
16344-
if (N00.getValueType().isVector() &&
16345-
N00.getValueType() == N10.getValueType() &&
16346-
N->getValueType(0) == N10.getValueType()) {
16347-
unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
16348-
SDValue SMin = DAG.getNode(
16349-
ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
16350-
DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
16351-
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
16352-
}
16353-
}
16354-
}
16355-
break;
16356-
}
16358+
case RISCVISD::TRUNCATE_VECTOR_VL:
16359+
return combineTruncOfSraSext(N, DAG);
1635716360
case ISD::TRUNCATE:
1635816361
return performTRUNCATECombine(N, DAG, Subtarget);
1635916362
case ISD::SELECT:

0 commit comments

Comments
 (0)