diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+  // This would be benefit for the cases where X and Y are both the same value
+  // type of low precision vectors. Since the truncate would be lowered into
+  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+  // restriction, such pattern would be expanded into a series of "vsetvli"
+  // and "vnsrl" instructions later to reach this point.
+  auto IsTruncNode = [](SDValue V) {
+    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL = V.getOperand(2);
+    auto *C = dyn_cast<ConstantSDNode>(VL);
+    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                           (isa<RegisterSDNode>(VL) &&
+                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  };
+
+  SDValue Op = N->getOperand(0);
+
+  // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+  // to distinguish such pattern.
+  while (IsTruncNode(Op)) {
+    if (!Op.hasOneUse())
+      return SDValue();
+    Op = Op.getOperand(0);
+  }
+
+  if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+    return SDValue();
+
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
+  if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+      N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  SDValue N10 = N1.getOperand(0);
+  if (!N00.getValueType().isVector() ||
+      N00.getValueType() != N10.getValueType() ||
+      N->getValueType(0) != N10.getValueType())
+    return SDValue();
+
+  unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+  SDValue SMin =
+      DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+                  DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+  return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
-  case RISCVISD::TRUNCATE_VECTOR_VL: {
-    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-    // This would be benefit for the cases where X and Y are both the same value
-    // type of low precision vectors. Since the truncate would be lowered into
-    // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-    // restriction, such pattern would be expanded into a series of "vsetvli"
-    // and "vnsrl" instructions later to reach this point.
-    auto IsTruncNode = [](SDValue V) {
-      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-        return false;
-      SDValue VL = V.getOperand(2);
-      auto *C = dyn_cast<ConstantSDNode>(VL);
-      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                             (isa<RegisterSDNode>(VL) &&
-                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-             IsVLMAXForVMSET;
-    };
-
-    SDValue Op = N->getOperand(0);
-
-    // We need to first find the inner level of TRUNCATE_VECTOR_VL node
-    // to distinguish such pattern.
-    while (IsTruncNode(Op)) {
-      if (!Op.hasOneUse())
-        return SDValue();
-      Op = Op.getOperand(0);
-    }
-
-    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
-      SDValue N0 = Op.getOperand(0);
-      SDValue N1 = Op.getOperand(1);
-      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
-          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
-        SDValue N00 = N0.getOperand(0);
-        SDValue N10 = N1.getOperand(0);
-        if (N00.getValueType().isVector() &&
-            N00.getValueType() == N10.getValueType() &&
-            N->getValueType(0) == N10.getValueType()) {
-          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
-          SDValue SMin = DAG.getNode(
-              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
-              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
-          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
-        }
-      }
-    }
-    break;
-  }
+  case RISCVISD::TRUNCATE_VECTOR_VL:
+    return combineTruncOfSraSext(N, DAG);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT: