@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
16087
16087
return true;
16088
16088
}
16089
16089
16090
+ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
16091
+ // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16092
+ // This would be benefit for the cases where X and Y are both the same value
16093
+ // type of low precision vectors. Since the truncate would be lowered into
16094
+ // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16095
+ // restriction, such pattern would be expanded into a series of "vsetvli"
16096
+ // and "vnsrl" instructions later to reach this point.
16097
+ auto IsTruncNode = [](SDValue V) {
16098
+ if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
16099
+ return false;
16100
+ SDValue VL = V.getOperand(2);
16101
+ auto *C = dyn_cast<ConstantSDNode>(VL);
16102
+ // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
16103
+ bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
16104
+ (isa<RegisterSDNode>(VL) &&
16105
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16106
+ return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
16107
+ };
16108
+
16109
+ SDValue Op = N->getOperand(0);
16110
+
16111
+ // We need to first find the inner level of TRUNCATE_VECTOR_VL node
16112
+ // to distinguish such pattern.
16113
+ while (IsTruncNode(Op)) {
16114
+ if (!Op.hasOneUse())
16115
+ return SDValue();
16116
+ Op = Op.getOperand(0);
16117
+ }
16118
+
16119
+ if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
16120
+ return SDValue();
16121
+
16122
+ SDValue N0 = Op.getOperand(0);
16123
+ SDValue N1 = Op.getOperand(1);
16124
+ if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
16125
+ N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
16126
+ return SDValue();
16127
+
16128
+ SDValue N00 = N0.getOperand(0);
16129
+ SDValue N10 = N1.getOperand(0);
16130
+ if (!N00.getValueType().isVector() ||
16131
+ N00.getValueType() != N10.getValueType() ||
16132
+ N->getValueType(0) != N10.getValueType())
16133
+ return SDValue();
16134
+
16135
+ unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
16136
+ SDValue SMin =
16137
+ DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
16138
+ DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
16139
+ return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
16140
+ }
16090
16141
16091
16142
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
16092
16143
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
16304
16355
}
16305
16356
}
16306
16357
return SDValue();
16307
- case RISCVISD::TRUNCATE_VECTOR_VL: {
16308
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
16309
- // This would be benefit for the cases where X and Y are both the same value
16310
- // type of low precision vectors. Since the truncate would be lowered into
16311
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
16312
- // restriction, such pattern would be expanded into a series of "vsetvli"
16313
- // and "vnsrl" instructions later to reach this point.
16314
- auto IsTruncNode = [](SDValue V) {
16315
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
16316
- return false;
16317
- SDValue VL = V.getOperand(2);
16318
- auto *C = dyn_cast<ConstantSDNode>(VL);
16319
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
16320
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
16321
- (isa<RegisterSDNode>(VL) &&
16322
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
16323
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
16324
- IsVLMAXForVMSET;
16325
- };
16326
-
16327
- SDValue Op = N->getOperand(0);
16328
-
16329
- // We need to first find the inner level of TRUNCATE_VECTOR_VL node
16330
- // to distinguish such pattern.
16331
- while (IsTruncNode(Op)) {
16332
- if (!Op.hasOneUse())
16333
- return SDValue();
16334
- Op = Op.getOperand(0);
16335
- }
16336
-
16337
- if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
16338
- SDValue N0 = Op.getOperand(0);
16339
- SDValue N1 = Op.getOperand(1);
16340
- if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
16341
- N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
16342
- SDValue N00 = N0.getOperand(0);
16343
- SDValue N10 = N1.getOperand(0);
16344
- if (N00.getValueType().isVector() &&
16345
- N00.getValueType() == N10.getValueType() &&
16346
- N->getValueType(0) == N10.getValueType()) {
16347
- unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
16348
- SDValue SMin = DAG.getNode(
16349
- ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
16350
- DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
16351
- return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
16352
- }
16353
- }
16354
- }
16355
- break;
16356
- }
16358
+ case RISCVISD::TRUNCATE_VECTOR_VL:
16359
+ return combineTruncOfSraSext(N, DAG);
16357
16360
case ISD::TRUNCATE:
16358
16361
return performTRUNCATECombine(N, DAG, Subtarget);
16359
16362
case ISD::SELECT:
0 commit comments